From 16a05f6097760cc2e2e5582ff17e03660e067ed3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Mar 2026 15:20:04 +0000 Subject: [PATCH 01/13] Initial plan From f116b44fea003a24dc2b8bb3144864f85539b455 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Mar 2026 15:33:04 +0000 Subject: [PATCH 02/13] Implement HTTP relay proxy load balancer with health checking and hot reload Co-authored-by: MejiroRina <70424266+MejiroRina@users.noreply.github.com> Agent-Logs-Url: https://github.com/Team-Haruki/http-proxy-lb/sessions/f3748e80-1418-4394-a40c-a54bceb41f43 --- Cargo.lock | 607 +++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 28 ++ README.md | 106 +++++++- config.example.yaml | 38 +++ src/config.rs | 241 ++++++++++++++++ src/health.rs | 67 +++++ src/main.rs | 130 +++++++++ src/proxy.rs | 649 ++++++++++++++++++++++++++++++++++++++++++++ src/upstream.rs | 430 +++++++++++++++++++++++++++++ 9 files changed, 2295 insertions(+), 1 deletion(-) create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 config.example.yaml create mode 100644 src/config.rs create mode 100644 src/health.rs create mode 100644 src/main.rs create mode 100644 src/proxy.rs create mode 100644 src/upstream.rs diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..550d691 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,607 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anstream" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + +[[package]] +name = "anstyle-parse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = 
"2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "clap" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" + +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "http-proxy-lb" +version = "0.1.0" +dependencies = [ + "anyhow", + "base64", + "clap", + "httparse", + "parking_lot", + "serde", + "tokio", + "tracing", + "tracing-subscriber", + "yaml_serde", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.183" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" + +[[package]] +name = "libyaml-rs" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e126dda6f34391ab7b444f9922055facc83c07a910da3eb16f1e4d9c45dc777" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" 
+dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + 
+[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "yaml_serde" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c7c1b1a6a7c8a6b2741a6c21a4f8918e51899b111cfa08d1288202656e3975" +dependencies = [ + "indexmap", + "itoa", + "libyaml-rs", + "ryu", + "serde", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..be30cf2 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "http-proxy-lb" +version = "0.1.0" +edition = "2021" +description = "HTTP relay proxy with upstream load balancing, health checking, and hot reload" +license = "MIT" + +[[bin]] +name = "http-proxy-lb" +path = "src/main.rs" + +[dependencies] +# Async runtime +tokio = { version = "1", features = ["full"] } +# HTTP/1.x request/response parsing +httparse = "1" +# Serialization +serde = { version = "1", features = ["derive"] } +yaml_serde = "0.10" +# CLI +clap = { version = "4", features = ["derive"] } +# Logging +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } +# Utilities +anyhow = "1" +base64 = "0.22" +parking_lot = "0.12" diff --git a/README.md b/README.md index 61d1f20..7feecac 100644 --- a/README.md +++ b/README.md @@ -1 +1,105 @@ -# http-proxy-lb \ No newline at end of file +# http-proxy-lb + +A high-availability HTTP relay proxy with upstream load balancing, health checking and hot config reload — written in Rust. 
+ +## Features + +| Feature | Description | +|---|---| +| **Round-robin load balancing** | Weighted round-robin across all online upstream proxies | +| **Best-selection mode** | Always routes to the upstream with the lowest latency + connection score | +| **Passive health detection** | Any upstream that fails to connect or respond is immediately marked offline and excluded from routing | +| **Active health checking** | Background task periodically probes offline upstreams via TCP connect; marks them online once they recover | +| **Automatic failover** | Failed requests are retried on a different upstream (up to `min(pool_size, 3)` retries) | +| **HTTP/1.1 keep-alive** | Client connections are reused across requests | +| **HTTPS tunneling** | `CONNECT` method is fully supported — traffic is tunneled through the upstream proxy | +| **Upstream authentication** | Per-upstream HTTP Proxy Basic Auth (`Proxy-Authorization`) | +| **Hot config reload** | Config file is re-read every `reload_interval_secs`; new/removed upstreams are applied without restart | + +## Installation + +```bash +git clone https://github.com/Team-Haruki/http-proxy-lb +cd http-proxy-lb +cargo build --release +# binary is at target/release/http-proxy-lb +``` + +## Quick start + +```bash +cp config.example.yaml config.yaml +# edit config.yaml to point at your upstream proxies +./target/release/http-proxy-lb --config config.yaml +``` + +Set `RUST_LOG=debug` for verbose logging. + +## Configuration + +```yaml +# Local address to listen on +listen: "127.0.0.1:8080" + +# Load-balancing mode: round_robin | best +mode: round_robin + +# How often to re-read the config file (seconds). 0 = disabled. 
+reload_interval_secs: 60 + +health_check: + interval_secs: 30 # probe interval for offline upstreams + timeout_secs: 5 # TCP-connect timeout per probe + +upstream: + - url: "http://proxy1.example.com:8080" + weight: 1 + + - url: "http://proxy2.example.com:8080" + weight: 2 + username: "user" + password: "secret" +``` + +### Balance modes + +* **`round_robin`** — Iterates through online upstreams in weighted order. Upstreams with a higher `weight` receive proportionally more connections. +* **`best`** — Selects the online upstream with the lowest *score*, where: + `score = latency_ema_ms + active_connections × 50` + Latency is an exponential moving average (α = 0.25) of observed response times. + +### Hot reload + +Edit `config.yaml` while the proxy is running. Within `reload_interval_secs` seconds the proxy will pick up the new upstream list. Existing upstreams (matched by URL) keep their online/offline state and statistics. + +## Usage with curl + +```bash +# Plain HTTP +curl -x http://127.0.0.1:8080 http://example.com/ + +# HTTPS (CONNECT tunnel) +curl -x http://127.0.0.1:8080 https://example.com/ +``` + +## Architecture + +``` +Client ──► [http-proxy-lb listener] + │ + ▼ + [UpstreamPool] ◄── [Active health checker] + (round-robin / + best select) + │ + ┌────────┼────────┐ + ▼ ▼ ▼ + Upstream1 Upstream2 Upstream3 + (online) (offline) (online) +``` + +Each client connection is handled in its own Tokio task. Upstream connections are established fresh per request (no upstream connection pooling). + +## License + +MIT diff --git a/config.example.yaml b/config.example.yaml new file mode 100644 index 0000000..468dc50 --- /dev/null +++ b/config.example.yaml @@ -0,0 +1,38 @@ +# http-proxy-lb example configuration +# Copy to config.yaml and edit as needed. 
+ +# Local address and port to listen on +listen: "127.0.0.1:8080" + +# Load-balancing mode: round_robin | best +# round_robin — weighted round-robin across all online upstreams +# best — select the online upstream with the lowest combined score +# (latency EMA + active-connection penalty) +mode: round_robin + +# Hot-reload: re-read config file every N seconds (0 = disabled) +reload_interval_secs: 60 + +# Active health-check settings +health_check: + # How often to probe offline upstreams (seconds) + interval_secs: 30 + # TCP-connect timeout when probing (seconds) + timeout_secs: 5 + +# Upstream proxy pool +# url — http://host:port (only HTTP proxies supported) +# weight — relative weight for round_robin (default: 1) +# username — optional HTTP proxy Basic-Auth username +# password — optional HTTP proxy Basic-Auth password +upstream: + - url: "http://proxy1.example.com:8080" + weight: 1 + + - url: "http://proxy2.example.com:8080" + weight: 2 + username: "user" + password: "secret" + + - url: "http://proxy3.example.com:3128" + weight: 1 diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..923963f --- /dev/null +++ b/src/config.rs @@ -0,0 +1,241 @@ +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::path::Path; +use std::time::SystemTime; + +// --------------------------------------------------------------------------- +// Top-level config +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + /// Local listen address, e.g. "127.0.0.1:8080" + pub listen: String, + + /// Load-balancing mode + #[serde(default)] + pub mode: BalanceMode, + + /// How often to re-read the config file (seconds). 0 = disabled. 
+ #[serde(default = "default_reload_interval")] + pub reload_interval_secs: u64, + + /// Active health-check parameters + #[serde(default)] + pub health_check: HealthCheckConfig, + + /// Upstream proxy list + #[serde(default)] + pub upstream: Vec, +} + +fn default_reload_interval() -> u64 { + 60 +} + +// --------------------------------------------------------------------------- +// Balance mode +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum BalanceMode { + /// Weighted round-robin over online upstreams (default) + #[default] + RoundRobin, + /// Pick the online upstream with the best (lowest) combined score + Best, +} + +// --------------------------------------------------------------------------- +// Health-check config +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HealthCheckConfig { + /// Seconds between successive active-probe rounds + #[serde(default = "default_hc_interval")] + pub interval_secs: u64, + /// TCP-connect timeout for each probe (seconds) + #[serde(default = "default_hc_timeout")] + pub timeout_secs: u64, +} + +impl Default for HealthCheckConfig { + fn default() -> Self { + Self { + interval_secs: default_hc_interval(), + timeout_secs: default_hc_timeout(), + } + } +} + +fn default_hc_interval() -> u64 { + 30 +} +fn default_hc_timeout() -> u64 { + 5 +} + +// --------------------------------------------------------------------------- +// Upstream config +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpstreamConfig { + /// Proxy URL, e.g. 
"http://1.2.3.4:8080" + pub url: String, + /// Relative weight for round_robin selection (default: 1) + #[serde(default = "default_weight")] + pub weight: u32, + /// Optional HTTP proxy Basic-Auth username + pub username: Option, + /// Optional HTTP proxy Basic-Auth password + pub password: Option, +} + +fn default_weight() -> u32 { + 1 +} + +impl UpstreamConfig { + /// Returns the `Proxy-Authorization: Basic …` header value, if auth is configured. + pub fn proxy_auth_header(&self) -> Option { + use base64::Engine; + match (&self.username, &self.password) { + (Some(u), Some(p)) => { + let encoded = base64::engine::general_purpose::STANDARD.encode(format!("{u}:{p}")); + Some(format!("Basic {encoded}")) + } + _ => None, + } + } + + /// Parse `host` and `port` from the upstream URL. + pub fn host_port(&self) -> Result<(String, u16)> { + let stripped = self + .url + .trim_start_matches("http://") + .trim_start_matches("https://") + .trim_end_matches('/'); + let mut iter = stripped.splitn(2, ':'); + let host = iter + .next() + .filter(|h| !h.is_empty()) + .ok_or_else(|| anyhow::anyhow!("Invalid upstream URL: {}", self.url))? + .to_string(); + let port: u16 = iter.next().and_then(|p| p.parse().ok()).unwrap_or(8080); + Ok((host, port)) + } +} + +// --------------------------------------------------------------------------- +// Loading +// --------------------------------------------------------------------------- + +pub fn load_config(path: &str) -> Result { + let content = + fs::read_to_string(path).with_context(|| format!("Failed to read config file: {path}"))?; + let config: Config = yaml_serde::from_str(&content) + .with_context(|| format!("Failed to parse config file: {path}"))?; + Ok(config) +} + +/// Returns the mtime of `path`, used to detect file changes for hot reload. 
+pub fn file_mtime(path: &str) -> Option { + Path::new(path).metadata().ok()?.modified().ok() +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_host_port_plain() { + let cfg = UpstreamConfig { + url: "http://proxy.example.com:8080".to_string(), + weight: 1, + username: None, + password: None, + }; + let (host, port) = cfg.host_port().unwrap(); + assert_eq!(host, "proxy.example.com"); + assert_eq!(port, 8080); + } + + #[test] + fn test_host_port_default_port() { + let cfg = UpstreamConfig { + url: "http://proxy.example.com".to_string(), + weight: 1, + username: None, + password: None, + }; + let (host, port) = cfg.host_port().unwrap(); + assert_eq!(host, "proxy.example.com"); + assert_eq!(port, 8080); + } + + #[test] + fn test_proxy_auth_header() { + let cfg = UpstreamConfig { + url: "http://p:1".to_string(), + weight: 1, + username: Some("alice".to_string()), + password: Some("s3cr3t".to_string()), + }; + let hdr = cfg.proxy_auth_header().unwrap(); + assert!(hdr.starts_with("Basic ")); + // base64("alice:s3cr3t") + use base64::Engine; + let expected = base64::engine::general_purpose::STANDARD.encode("alice:s3cr3t"); + assert_eq!(hdr, format!("Basic {expected}")); + } + + #[test] + fn test_proxy_auth_header_none() { + let cfg = UpstreamConfig { + url: "http://p:1".to_string(), + weight: 1, + username: None, + password: None, + }; + assert!(cfg.proxy_auth_header().is_none()); + } + + #[test] + fn test_parse_yaml() { + let yaml = r#" +listen: "127.0.0.1:9090" +mode: best +upstream: + - url: "http://a:1" + weight: 2 + - url: "http://b:2" +"#; + let cfg: Config = yaml_serde::from_str(yaml).unwrap(); + assert_eq!(cfg.listen, "127.0.0.1:9090"); + assert_eq!(cfg.mode, BalanceMode::Best); + assert_eq!(cfg.upstream.len(), 2); + assert_eq!(cfg.upstream[0].weight, 2); + 
assert_eq!(cfg.upstream[1].weight, 1); // default + } + + #[test] + fn test_parse_yaml_defaults() { + let yaml = r#" +listen: "0.0.0.0:8080" +upstream: [] +"#; + let cfg: Config = yaml_serde::from_str(yaml).unwrap(); + assert_eq!(cfg.mode, BalanceMode::RoundRobin); + assert_eq!(cfg.reload_interval_secs, 60); + assert_eq!(cfg.health_check.interval_secs, 30); + assert_eq!(cfg.health_check.timeout_secs, 5); + } +} diff --git a/src/health.rs b/src/health.rs new file mode 100644 index 0000000..cb7932a --- /dev/null +++ b/src/health.rs @@ -0,0 +1,67 @@ +use std::sync::Arc; +use std::time::Duration; + +use tokio::net::TcpStream; +use tokio::time::timeout; +use tracing::{debug, info}; + +use crate::upstream::UpstreamPool; + +/// Runs the active health-checker in the background. +/// +/// Every `interval_secs` seconds it iterates over every **offline** upstream, +/// attempts a TCP-connect with `timeout_secs` deadline, and marks the upstream +/// **online** if the connection succeeds. +pub async fn run_health_checker(pool: Arc, interval_secs: u64, timeout_secs: u64) { + let interval = Duration::from_secs(interval_secs); + let probe_timeout = Duration::from_secs(timeout_secs); + + loop { + tokio::time::sleep(interval).await; + probe_offline(&pool, probe_timeout).await; + } +} + +async fn probe_offline(pool: &Arc, probe_timeout: Duration) { + let offline = pool.offline_entries(); + if offline.is_empty() { + return; + } + + debug!( + count = offline.len(), + "active health check: probing offline upstreams" + ); + + let tasks: Vec<_> = offline + .into_iter() + .map(|entry| { + let t = probe_timeout; + tokio::spawn(async move { + let addr = match entry.host_port() { + Ok((h, p)) => format!("{h}:{p}"), + Err(e) => { + debug!(url = %entry.config.url, error = %e, "invalid upstream URL"); + return; + } + }; + match timeout(t, TcpStream::connect(&addr)).await { + Ok(Ok(_)) => { + info!(upstream = %entry.config.url, "health check OK — marking online"); + entry.mark_online(); + } + 
Ok(Err(e)) => { + debug!(upstream = %entry.config.url, error = %e, "health check failed"); + } + Err(_) => { + debug!(upstream = %entry.config.url, "health check timed out"); + } + } + }) + }) + .collect(); + + for task in tasks { + let _ = task.await; + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..79a239f --- /dev/null +++ b/src/main.rs @@ -0,0 +1,130 @@ +mod config; +mod health; +mod proxy; +mod upstream; + +use std::sync::Arc; +use std::time::Duration; + +use anyhow::{Context, Result}; +use clap::Parser; +use tokio::net::TcpListener; +use tracing::{error, info, warn}; + +use config::{file_mtime, load_config}; +use upstream::UpstreamPool; + +// --------------------------------------------------------------------------- +// CLI +// --------------------------------------------------------------------------- + +/// HTTP proxy load balancer — relay local HTTP proxy traffic through a pool of +/// upstream proxies with health checking and hot reload. +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Cli { + /// Path to the YAML configuration file + #[arg(short, long, default_value = "config.yaml")] + config: String, +} + +// --------------------------------------------------------------------------- +// Entry point +// --------------------------------------------------------------------------- + +#[tokio::main] +async fn main() -> Result<()> { + // Initialise structured logging (RUST_LOG controls verbosity; default: info) + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), + ) + .init(); + + let cli = Cli::parse(); + let cfg_path = cli.config.clone(); + + // Load initial config + let cfg = + load_config(&cfg_path).with_context(|| format!("failed to load config from {cfg_path}"))?; + + info!(listen = %cfg.listen, mode = ?cfg.mode, upstreams = cfg.upstream.len(), "starting"); + + let 
listen_addr = cfg.listen.clone(); + let mode = cfg.mode; + let reload_interval = cfg.reload_interval_secs; + let hc_interval = cfg.health_check.interval_secs; + let hc_timeout = cfg.health_check.timeout_secs; + + // Build upstream pool + let pool = UpstreamPool::from_config(&cfg); + + // Bind listener + let listener = TcpListener::bind(&listen_addr) + .await + .with_context(|| format!("failed to bind to {listen_addr}"))?; + info!(addr = %listen_addr, "listening"); + + // --- Spawn active health checker --- + { + let pool = Arc::clone(&pool); + tokio::spawn(async move { + health::run_health_checker(pool, hc_interval, hc_timeout).await; + }); + } + + // --- Spawn hot-reload watcher --- + if reload_interval > 0 { + let pool = Arc::clone(&pool); + let path = cfg_path.clone(); + tokio::spawn(async move { + run_hot_reload(pool, path, reload_interval).await; + }); + } + + // --- Accept loop --- + loop { + match listener.accept().await { + Ok((stream, peer)) => { + debug_assert_ne!(peer.port(), 0); + let pool = Arc::clone(&pool); + tokio::spawn(async move { + proxy::handle_client(stream, pool, mode).await; + }); + } + Err(e) => { + error!(error = %e, "accept error"); + } + } + } +} + +// --------------------------------------------------------------------------- +// Hot-reload loop +// --------------------------------------------------------------------------- + +async fn run_hot_reload(pool: Arc, cfg_path: String, interval_secs: u64) { + let mut last_mtime = file_mtime(&cfg_path); + let interval = Duration::from_secs(interval_secs); + + loop { + tokio::time::sleep(interval).await; + + let current_mtime = file_mtime(&cfg_path); + if current_mtime == last_mtime { + continue; + } + + info!(path = %cfg_path, "config file changed — reloading"); + match load_config(&cfg_path) { + Ok(new_cfg) => { + pool.reload(&new_cfg); + last_mtime = current_mtime; + } + Err(e) => { + warn!(error = %e, "hot reload failed — keeping current config"); + } + } + } +} diff --git a/src/proxy.rs 
b/src/proxy.rs new file mode 100644 index 0000000..4edd9fd --- /dev/null +++ b/src/proxy.rs @@ -0,0 +1,649 @@ +//! HTTP/1.x proxy request handler. +//! +//! Supports: +//! * `CONNECT` tunneling (HTTPS) +//! * Plain HTTP forwarding +//! * HTTP/1.1 persistent connections (keep-alive) from the client +//! * Retry on upstream failure (passive health detection) + +use std::sync::Arc; +use std::time::Instant; + +use anyhow::{anyhow, bail, Result}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::time::{timeout, Duration}; +use tracing::{debug, warn}; + +use crate::config::BalanceMode; +use crate::upstream::UpstreamPool; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +const MAX_HEADER_SIZE: usize = 64 * 1024; // 64 KiB +const CONNECT_TIMEOUT: Duration = Duration::from_secs(15); +/// Maximum number of upstream retry attempts per request. +const MAX_RETRIES: usize = 3; + +// --------------------------------------------------------------------------- +// Public entry point +// --------------------------------------------------------------------------- + +/// Accept a single client TCP connection and serve HTTP proxy requests on it. +/// Supports keep-alive: loops until the client closes or sends `Connection: close`. 
+pub async fn handle_client(mut client: TcpStream, pool: Arc, mode: BalanceMode) { + let peer = client + .peer_addr() + .map(|a| a.to_string()) + .unwrap_or_default(); + debug!(peer = %peer, "new client connection"); + + loop { + match read_headers(&mut client).await { + Ok(Some(buf)) => { + match dispatch(&mut client, &buf, &pool, mode).await { + Ok(true) => continue, // keep-alive: read next request + Ok(false) => break, + Err(e) => { + debug!(peer = %peer, error = %e, "request dispatch error"); + break; + } + } + } + Ok(None) => break, // client closed connection + Err(e) => { + debug!(peer = %peer, error = %e, "header read error"); + break; + } + } + } + + debug!(peer = %peer, "client connection closed"); +} + +// --------------------------------------------------------------------------- +// Request dispatch +// --------------------------------------------------------------------------- + +/// Returns `Ok(true)` if the client connection should be kept alive. +async fn dispatch( + client: &mut TcpStream, + buf: &[u8], + pool: &Arc, + mode: BalanceMode, +) -> Result { + // --- parse request line + headers --- + let mut raw_headers = [httparse::EMPTY_HEADER; 96]; + let mut req = httparse::Request::new(&mut raw_headers); + let body_offset = match req.parse(buf)? { + httparse::Status::Complete(n) => n, + httparse::Status::Partial => bail!("incomplete request headers"), + }; + + let method = req + .method + .ok_or_else(|| anyhow!("missing method"))? + .to_string(); + let path = req.path.ok_or_else(|| anyhow!("missing path"))?.to_string(); + let version = req.version.unwrap_or(1); + let headers = req.headers; + + // Content-Length / Transfer-Encoding of the *request* body + let req_content_length: Option = + get_header(headers, "content-length").and_then(|v| v.parse().ok()); + let req_is_chunked = get_header(headers, "transfer-encoding") + .map(|v| v.to_ascii_lowercase().contains("chunked")) + .unwrap_or(false); + + // Does the client want a persistent connection? 
+ let client_keep_alive = client_wants_keep_alive(headers, version); + + // --- CONNECT (HTTPS tunnel) --- + if method.eq_ignore_ascii_case("CONNECT") { + let (host, port) = parse_connect_target(&path)?; + return handle_connect(client, &host, port, pool, mode) + .await + .map(|_| false); + } + + // --- Plain HTTP --- + handle_http( + client, + &method, + &path, + version, + buf, + body_offset, + req_content_length, + req_is_chunked, + client_keep_alive, + pool, + mode, + ) + .await +} + +// --------------------------------------------------------------------------- +// CONNECT tunnel +// --------------------------------------------------------------------------- + +async fn handle_connect( + client: &mut TcpStream, + host: &str, + port: u16, + pool: &Arc, + mode: BalanceMode, +) -> Result<()> { + let mut tried: Vec = Vec::new(); + let max_retries = pool.len().min(MAX_RETRIES); + + loop { + let (idx, upstream) = match pool.select(mode, &tried) { + Some(u) => u, + None => { + let _ = write_status(client, 502, "No upstream available").await; + bail!("no upstream available for CONNECT"); + } + }; + tried.push(idx); + + let (up_host, up_port) = upstream.host_port()?; + let addr = format!("{up_host}:{up_port}"); + + // Connect to upstream proxy + let mut up_stream = match connect_upstream(&addr).await { + Ok(s) => s, + Err(e) => { + warn!(upstream = %upstream.config.url, error = %e, "CONNECT: upstream TCP failed"); + upstream.mark_offline(); + upstream.record_failure(); + if tried.len() >= max_retries { + let _ = write_status(client, 502, "Bad Gateway").await; + bail!("all retries exhausted for CONNECT"); + } + continue; + } + }; + + // Send CONNECT to upstream proxy + let connect_req = build_upstream_connect(host, port, upstream.config.proxy_auth_header()); + if let Err(e) = up_stream.write_all(connect_req.as_bytes()).await { + warn!(upstream = %upstream.config.url, error = %e, "CONNECT: write to upstream failed"); + upstream.mark_offline(); + upstream.record_failure(); 
+ if tried.len() >= max_retries { + let _ = write_status(client, 502, "Bad Gateway").await; + bail!("all retries exhausted for CONNECT write"); + } + continue; + } + + // Read upstream's response to our CONNECT + match read_connect_response(&mut up_stream).await { + Ok(200) => { + // Success — tell client we're connected + client + .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n") + .await?; + upstream + .active_conns + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let t0 = Instant::now(); + let result = tunnel(client, up_stream).await; + upstream + .active_conns + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + upstream.record_success(t0.elapsed().as_millis() as u64); + return result; + } + Ok(status) => { + // Upstream rejected CONNECT (e.g. auth required, forbidden) + debug!(upstream = %upstream.config.url, status, "CONNECT rejected by upstream"); + let _ = write_status(client, status as u16, "Upstream rejected CONNECT").await; + bail!("upstream rejected CONNECT with status {status}"); + } + Err(e) => { + warn!(upstream = %upstream.config.url, error = %e, "CONNECT: bad upstream response"); + upstream.mark_offline(); + upstream.record_failure(); + if tried.len() >= max_retries { + let _ = write_status(client, 502, "Bad Gateway").await; + bail!("all retries exhausted reading CONNECT response"); + } + continue; + } + } + } +} + +/// Bidirectional copy between client and upstream until either side closes. +async fn tunnel(client: &mut TcpStream, mut upstream: TcpStream) -> Result<()> { + let (mut cr, mut cw) = tokio::io::split(client); + let (mut ur, mut uw) = tokio::io::split(&mut upstream); + tokio::select! 
{ + r = tokio::io::copy(&mut cr, &mut uw) => { r?; } + r = tokio::io::copy(&mut ur, &mut cw) => { r?; } + } + Ok(()) +} + +// --------------------------------------------------------------------------- +// Plain HTTP forwarding +// --------------------------------------------------------------------------- + +#[allow(clippy::too_many_arguments)] +async fn handle_http( + client: &mut TcpStream, + method: &str, + _path: &str, + _version: u8, + req_buf: &[u8], // raw bytes: headers + any already-buffered body bytes + body_offset: usize, // where headers end within req_buf + req_content_length: Option, + req_is_chunked: bool, + client_keep_alive: bool, + pool: &Arc, + mode: BalanceMode, +) -> Result { + let mut tried: Vec = Vec::new(); + let max_retries = pool.len().min(MAX_RETRIES); + + loop { + let (idx, upstream) = match pool.select(mode, &tried) { + Some(u) => u, + None => { + let _ = write_status(client, 502, "No upstream available").await; + return Ok(false); + } + }; + tried.push(idx); + + let (up_host, up_port) = upstream.host_port()?; + let addr = format!("{up_host}:{up_port}"); + + // --- Connect to upstream proxy --- + let mut up_stream = match connect_upstream(&addr).await { + Ok(s) => s, + Err(e) => { + warn!(upstream = %upstream.config.url, error = %e, "HTTP: upstream TCP failed"); + upstream.mark_offline(); + upstream.record_failure(); + if tried.len() >= max_retries { + let _ = write_status(client, 502, "Bad Gateway").await; + return Ok(false); + } + continue; + } + }; + + // --- Build and send request headers --- + let fwd_headers = + rewrite_request_headers(req_buf, body_offset, upstream.config.proxy_auth_header()); + if let Err(e) = up_stream.write_all(&fwd_headers).await { + warn!(upstream = %upstream.config.url, error = %e, "HTTP: write headers failed"); + upstream.mark_offline(); + upstream.record_failure(); + if tried.len() >= max_retries { + let _ = write_status(client, 502, "Bad Gateway").await; + return Ok(false); + } + continue; + } + + // 
--- Stream request body --- + let already_in_buf = req_buf.len().saturating_sub(body_offset); + if already_in_buf > 0 { + // Body bytes that arrived in the same read as the headers + if let Err(e) = up_stream.write_all(&req_buf[body_offset..]).await { + warn!(upstream = %upstream.config.url, error = %e, "HTTP: write buffered body failed"); + upstream.mark_offline(); + upstream.record_failure(); + if tried.len() >= max_retries { + return Ok(false); + } + continue; + } + } + if let Err(e) = forward_body( + client, + &mut up_stream, + req_content_length, + req_is_chunked, + already_in_buf as u64, + ) + .await + { + warn!(error = %e, "HTTP: body forward error"); + return Ok(false); + } + + // --- Stream response back --- + let t0 = Instant::now(); + upstream + .active_conns + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let resp_result = forward_response(&mut up_stream, client, method).await; + upstream + .active_conns + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + + match resp_result { + Ok(upstream_keep_alive) => { + upstream.record_success(t0.elapsed().as_millis() as u64); + return Ok(client_keep_alive && upstream_keep_alive); + } + Err(e) => { + warn!(upstream = %upstream.config.url, error = %e, "HTTP: response forward failed"); + upstream.mark_offline(); + upstream.record_failure(); + if tried.len() >= max_retries { + return Ok(false); + } + continue; + } + } + } +} + +// --------------------------------------------------------------------------- +// Body forwarding helpers +// --------------------------------------------------------------------------- + +/// Forward remaining request body bytes (after headers) from `client` → `upstream`. +/// +/// * If `content_length` is known, copy exactly that many bytes minus what was +/// already buffered (`already_sent`). +/// * If chunked, forward raw chunked stream until the terminal `0\r\n\r\n`. +/// * If neither, assume there is no body and return immediately. 
+async fn forward_body(
+    client: &mut TcpStream,
+    upstream: &mut TcpStream,
+    content_length: Option<u64>,
+    is_chunked: bool,
+    already_sent: u64,
+) -> Result<()> {
+    if let Some(len) = content_length {
+        let remaining = len.saturating_sub(already_sent);
+        if remaining > 0 {
+            copy_exact(client, upstream, remaining).await?;
+        }
+    } else if is_chunked {
+        copy_chunked(client, upstream).await?;
+    }
+    // else: no body (GET / HEAD / etc.)
+    Ok(())
+}
+
+/// Forward HTTP response from `upstream` → `client`.
+/// Returns `Ok(true)` if the upstream connection is being kept alive
+/// (i.e. the caller may be able to issue another request on the same upstream
+/// socket — we don't reuse upstream sockets here, but the value is used to
+/// determine whether the *client* connection should be kept alive).
+async fn forward_response(
+    upstream: &mut TcpStream,
+    client: &mut TcpStream,
+    req_method: &str,
+) -> Result<bool> {
+    // Read response headers
+    let header_buf = read_headers(upstream)
+        .await?
+        .ok_or_else(|| anyhow!("upstream closed connection without response"))?;
+
+    let mut raw_headers = [httparse::EMPTY_HEADER; 96];
+    let mut resp = httparse::Response::new(&mut raw_headers);
+    let body_offset = match resp.parse(&header_buf)?
{ + httparse::Status::Complete(n) => n, + httparse::Status::Partial => bail!("incomplete response headers from upstream"), + }; + + let status = resp.code.unwrap_or(0); + let version = resp.version.unwrap_or(1); + let headers = resp.headers; + + let resp_content_length: Option = + get_header(headers, "content-length").and_then(|v| v.parse().ok()); + let resp_is_chunked = get_header(headers, "transfer-encoding") + .map(|v| v.to_ascii_lowercase().contains("chunked")) + .unwrap_or(false); + let upstream_keep_alive = upstream_wants_keep_alive(headers, version); + + // Forward headers verbatim + client.write_all(&header_buf[..body_offset]).await?; + + // Forward any body bytes that were buffered with the headers + let already_in_buf = header_buf.len().saturating_sub(body_offset); + if already_in_buf > 0 { + client.write_all(&header_buf[body_offset..]).await?; + } + + // Determine whether this response has a body + let has_body = !req_method.eq_ignore_ascii_case("HEAD") + && status != 204 + && status != 304 + && !(100..200).contains(&status); + + if has_body { + if let Some(len) = resp_content_length { + let remaining = len.saturating_sub(already_in_buf as u64); + if remaining > 0 { + copy_exact(upstream, client, remaining).await?; + } + } else if resp_is_chunked { + copy_chunked(upstream, client).await?; + } else { + // No Content-Length and not chunked: read until upstream closes. + // We must also close the client connection afterwards. + tokio::io::copy(upstream, client).await?; + return Ok(false); + } + } + + Ok(upstream_keep_alive) +} + +// --------------------------------------------------------------------------- +// Low-level I/O helpers +// --------------------------------------------------------------------------- + +/// Read a raw chunked stream from `src` and write to `dst` until the +/// terminal `0\r\n\r\n` chunk. 
+async fn copy_chunked(src: &mut TcpStream, dst: &mut TcpStream) -> Result<()> {
+    let mut buf = vec![0u8; 8 * 1024];
+    // We look for the terminal chunk marker in what we forward.
+    // Since this is a relay, we forward bytes verbatim and detect the end.
+    let mut trailer = Vec::new();
+    loop {
+        let n = src.read(&mut buf).await?;
+        if n == 0 {
+            break;
+        }
+        dst.write_all(&buf[..n]).await?;
+        // Accumulate last 8 bytes to detect "0\r\n\r\n"
+        trailer.extend_from_slice(&buf[..n]);
+        if trailer.len() > 8 {
+            trailer.drain(..trailer.len() - 8);
+        }
+        if trailer.windows(5).any(|w| w == b"0\r\n\r\n") {
+            break;
+        }
+    }
+    Ok(())
+}
+
+/// Copy exactly `bytes` bytes from `src` to `dst`.
+async fn copy_exact(src: &mut TcpStream, dst: &mut TcpStream, mut bytes: u64) -> Result<()> {
+    let mut buf = vec![0u8; 8 * 1024];
+    while bytes > 0 {
+        let to_read = (buf.len() as u64).min(bytes) as usize;
+        let n = src.read(&mut buf[..to_read]).await?;
+        if n == 0 {
+            bail!("unexpected EOF: expected {bytes} more bytes");
+        }
+        dst.write_all(&buf[..n]).await?;
+        bytes -= n as u64;
+    }
+    Ok(())
+}
+
+/// Read bytes from `stream` into a growing buffer until the HTTP header
+/// terminator `\r\n\r\n` is found.
+///
+/// Returns `Ok(None)` when the connection is closed before any bytes are read
+/// (clean EOF).
+async fn read_headers(stream: &mut TcpStream) -> Result<Option<Vec<u8>>> {
+    let mut buf: Vec<u8> = Vec::with_capacity(4096);
+    let mut tmp = [0u8; 4096];
+    let mut first_read = true;
+
+    loop {
+        let n = stream.read(&mut tmp).await?;
+        if n == 0 {
+            if first_read {
+                return Ok(None); // clean EOF between requests
+            }
+            bail!("connection closed mid-headers");
+        }
+        first_read = false;
+        buf.extend_from_slice(&tmp[..n]);
+
+        if buf.windows(4).any(|w| w == b"\r\n\r\n") {
+            // Body bytes that arrived in the same read as the headers stay in
+            // `buf` (common for small POST bodies); callers forward them and
+            // then consume any remaining body bytes from the socket.
+ return Ok(Some(buf)); + } + + if buf.len() > MAX_HEADER_SIZE { + bail!("request headers exceed {MAX_HEADER_SIZE} bytes"); + } + } +} + +// --------------------------------------------------------------------------- +// Request / response header helpers +// --------------------------------------------------------------------------- + +/// Rewrite request headers for forwarding through an upstream proxy: +/// * Keeps the request line verbatim (absolute URI is already correct for a +/// proxy-to-proxy hop). +/// * Removes `Proxy-Authorization` from the client to avoid leaking +/// the client's upstream credentials to the next hop. +/// * Injects our upstream's `Proxy-Authorization` if required. +fn rewrite_request_headers(raw: &[u8], body_offset: usize, proxy_auth: Option) -> Vec { + let header_bytes = &raw[..body_offset]; + let header_str = String::from_utf8_lossy(header_bytes); + + let mut out = String::with_capacity(header_bytes.len() + 64); + let mut first_line = true; + + for line in header_str.split("\r\n") { + if first_line { + out.push_str(line); + out.push_str("\r\n"); + first_line = false; + continue; + } + if line.is_empty() { + // Header section ends — inject upstream Proxy-Authorization if needed + if let Some(ref auth) = proxy_auth { + out.push_str(&format!("Proxy-Authorization: {auth}\r\n")); + } + out.push_str("\r\n"); + break; + } + // Drop the *client*'s Proxy-Authorization header (we add ours instead) + if line + .to_ascii_lowercase() + .starts_with("proxy-authorization:") + { + continue; + } + out.push_str(line); + out.push_str("\r\n"); + } + + out.into_bytes() +} + +// --------------------------------------------------------------------------- +// CONNECT helpers +// --------------------------------------------------------------------------- + +fn build_upstream_connect(host: &str, port: u16, proxy_auth: Option) -> String { + let mut req = format!("CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n"); + if let Some(auth) = proxy_auth { + 
req.push_str(&format!("Proxy-Authorization: {auth}\r\n"));
+    }
+    req.push_str("\r\n");
+    req
+}
+
+/// Read the first response line from an upstream after sending CONNECT.
+/// Returns the HTTP status code.
+async fn read_connect_response(upstream: &mut TcpStream) -> Result<u32> {
+    let buf = read_headers(upstream)
+        .await?
+        .ok_or_else(|| anyhow!("upstream closed connection"))?;
+    let mut raw_headers = [httparse::EMPTY_HEADER; 16];
+    let mut resp = httparse::Response::new(&mut raw_headers);
+    resp.parse(&buf)?;
+    resp.code
+        .map(|c| c as u32)
+        .ok_or_else(|| anyhow!("no status code in CONNECT response"))
+}
+
+// ---------------------------------------------------------------------------
+// Misc helpers
+// ---------------------------------------------------------------------------
+
+async fn connect_upstream(addr: &str) -> Result<TcpStream> {
+    timeout(CONNECT_TIMEOUT, TcpStream::connect(addr))
+        .await
+        .map_err(|_| anyhow!("connection to {addr} timed out"))?
+        .map_err(|e| anyhow!("TCP connect to {addr} failed: {e}"))
+}
+
+fn parse_connect_target(authority: &str) -> Result<(String, u16)> {
+    let mut parts = authority.splitn(2, ':');
+    let host = parts
+        .next()
+        .filter(|h| !h.is_empty())
+        .ok_or_else(|| anyhow!("invalid CONNECT target: {authority}"))?
+ .to_string(); + let port: u16 = parts.next().and_then(|p| p.parse().ok()).unwrap_or(443); + Ok((host, port)) +} + +fn get_header<'a>(headers: &'a [httparse::Header<'a>], name: &str) -> Option<&'a str> { + headers.iter().find_map(|h| { + if h.name.eq_ignore_ascii_case(name) { + std::str::from_utf8(h.value).ok() + } else { + None + } + }) +} + +fn client_wants_keep_alive(headers: &[httparse::Header<'_>], version: u8) -> bool { + match get_header(headers, "connection").map(|v| v.to_ascii_lowercase()) { + Some(ref v) if v.contains("close") => false, + Some(ref v) if v.contains("keep-alive") => true, + _ => version == 1, // HTTP/1.1 default is keep-alive + } +} + +fn upstream_wants_keep_alive(headers: &[httparse::Header<'_>], version: u8) -> bool { + match get_header(headers, "connection").map(|v| v.to_ascii_lowercase()) { + Some(ref v) if v.contains("close") => false, + Some(ref v) if v.contains("keep-alive") => true, + _ => version == 1, + } +} + +async fn write_status(stream: &mut TcpStream, code: u16, msg: &str) -> Result<()> { + let resp = format!("HTTP/1.1 {code} {msg}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"); + stream.write_all(resp.as_bytes()).await?; + Ok(()) +} diff --git a/src/upstream.rs b/src/upstream.rs new file mode 100644 index 0000000..f176e59 --- /dev/null +++ b/src/upstream.rs @@ -0,0 +1,430 @@ +use std::sync::atomic::{AtomicI64, AtomicU32, AtomicU64, AtomicU8, Ordering}; +use std::sync::Arc; +use std::time::Instant; + +use anyhow::Result; +use parking_lot::Mutex; +use tracing::{debug, info, warn}; + +use crate::config::{BalanceMode, Config, UpstreamConfig}; + +// --------------------------------------------------------------------------- +// Upstream state +// --------------------------------------------------------------------------- + +const STATE_ONLINE: u8 = 0; +const STATE_OFFLINE: u8 = 1; + +// --------------------------------------------------------------------------- +// UpstreamEntry — one upstream proxy with live state & 
statistics +// --------------------------------------------------------------------------- + +pub struct UpstreamEntry { + pub config: UpstreamConfig, + /// 0 = online, 1 = offline + state: AtomicU8, + /// Number of in-flight connections currently routed through this upstream + pub active_conns: AtomicI64, + /// Exponential moving average of successful response latency (stored as + /// whole milliseconds; 0 means "no sample yet"). + latency_ema_ms: AtomicU64, + /// Consecutive passive failures — for future tuning. + pub consec_failures: AtomicU32, + /// Wall-clock time of last state change (for logging only). + last_state_change: Mutex, +} + +impl UpstreamEntry { + fn new(config: UpstreamConfig) -> Arc { + Arc::new(Self { + config, + state: AtomicU8::new(STATE_ONLINE), + active_conns: AtomicI64::new(0), + latency_ema_ms: AtomicU64::new(0), + consec_failures: AtomicU32::new(0), + last_state_change: Mutex::new(Instant::now()), + }) + } + + // ---- state accessors --------------------------------------------------- + + pub fn is_online(&self) -> bool { + self.state.load(Ordering::Acquire) == STATE_ONLINE + } + + pub fn mark_offline(&self) { + if self + .state + .compare_exchange( + STATE_ONLINE, + STATE_OFFLINE, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + *self.last_state_change.lock() = Instant::now(); + warn!(upstream = %self.config.url, "marked OFFLINE"); + } + } + + pub fn mark_online(&self) { + if self + .state + .compare_exchange( + STATE_OFFLINE, + STATE_ONLINE, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + self.consec_failures.store(0, Ordering::Release); + *self.last_state_change.lock() = Instant::now(); + info!(upstream = %self.config.url, "marked ONLINE"); + } + } + + // ---- statistics -------------------------------------------------------- + + /// Record a successful request with the measured round-trip latency. 
+ pub fn record_success(&self, latency_ms: u64) { + self.consec_failures.store(0, Ordering::Release); + // EMA(α=0.25): new = 0.25·sample + 0.75·old + let old = self.latency_ema_ms.load(Ordering::Relaxed); + let new_ema = if old == 0 { + latency_ms + } else { + (latency_ms + old * 3) / 4 + }; + self.latency_ema_ms.store(new_ema, Ordering::Relaxed); + } + + pub fn record_failure(&self) { + self.consec_failures.fetch_add(1, Ordering::Relaxed); + } + + /// Lower score = better upstream for "best" mode. + /// Combines latency EMA and active-connection penalty. + pub fn score(&self) -> u64 { + let latency = self.latency_ema_ms.load(Ordering::Relaxed); + let base = if latency == 0 { 50 } else { latency }; // assume 50 ms if unknown + let active = self.active_conns.load(Ordering::Relaxed).max(0) as u64; + base + active * 50 // each in-flight connection adds 50 ms penalty + } + + // ---- helpers ----------------------------------------------------------- + + pub fn host_port(&self) -> Result<(String, u16)> { + self.config.host_port() + } +} + +// --------------------------------------------------------------------------- +// UpstreamPool — thread-safe collection of upstream entries +// --------------------------------------------------------------------------- + +pub struct UpstreamPool { + entries: Mutex>>, + /// Monotonically increasing counter used for round-robin cursor. + rr_counter: AtomicU64, +} + +impl UpstreamPool { + /// Build an initial pool from a freshly loaded `Config`. + pub fn from_config(cfg: &Config) -> Arc { + let entries: Vec> = cfg + .upstream + .iter() + .map(|c| UpstreamEntry::new(c.clone())) + .collect(); + info!(count = entries.len(), "upstream pool initialised"); + Arc::new(Self { + entries: Mutex::new(entries), + rr_counter: AtomicU64::new(0), + }) + } + + // ---- selection --------------------------------------------------------- + + /// Select an online upstream, skipping indices in `exclude`. 
+ /// Returns `(pool_index, entry)` or `None` when no online upstream remains. + pub fn select( + &self, + mode: BalanceMode, + exclude: &[usize], + ) -> Option<(usize, Arc)> { + match mode { + BalanceMode::RoundRobin => self.select_rr(exclude), + BalanceMode::Best => self.select_best(exclude), + } + } + + fn select_rr(&self, exclude: &[usize]) -> Option<(usize, Arc)> { + let entries = self.entries.lock(); + // Build a weighted candidate list. + let candidates: Vec<(usize, &Arc)> = entries + .iter() + .enumerate() + .filter(|(i, e)| e.is_online() && !exclude.contains(i)) + .flat_map(|(i, e)| std::iter::repeat_n((i, e), e.config.weight.max(1) as usize)) + .collect(); + if candidates.is_empty() { + return None; + } + let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) as usize % candidates.len(); + let (pool_idx, entry) = candidates[idx]; + Some((pool_idx, Arc::clone(entry))) + } + + fn select_best(&self, exclude: &[usize]) -> Option<(usize, Arc)> { + let entries = self.entries.lock(); + entries + .iter() + .enumerate() + .filter(|(i, e)| e.is_online() && !exclude.contains(i)) + .min_by_key(|(_, e)| e.score()) + .map(|(i, e)| (i, Arc::clone(e))) + } + + // ---- queries ----------------------------------------------------------- + + pub fn len(&self) -> usize { + self.entries.lock().len() + } + + /// Snapshot of all offline entries (for the health checker). + pub fn offline_entries(&self) -> Vec> { + self.entries + .lock() + .iter() + .filter(|e| !e.is_online()) + .cloned() + .collect() + } + + // ---- hot reload -------------------------------------------------------- + + /// Replace the upstream list from a new config while preserving state/stats + /// for upstreams that already exist (matched by URL). + pub fn reload(&self, new_cfg: &Config) { + let mut entries = self.entries.lock(); + + let new_entries: Vec> = new_cfg + .upstream + .iter() + .map(|new_cfg| { + // Reuse existing entry if URL matches (preserves state/stats). 
+ if let Some(existing) = entries.iter().find(|e| e.config.url == new_cfg.url) { + // Update weight and auth in case they changed. + // State and stats are preserved. + debug!(url = %new_cfg.url, "reusing existing upstream entry"); + Arc::clone(existing) + } else { + info!(url = %new_cfg.url, "adding new upstream"); + UpstreamEntry::new(new_cfg.clone()) + } + }) + .collect(); + + // Log removals. + for old in entries.iter() { + if !new_cfg.upstream.iter().any(|n| n.url == old.config.url) { + info!(url = %old.config.url, "removing upstream"); + } + } + + *entries = new_entries; + info!(count = entries.len(), "upstream pool reloaded"); + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{BalanceMode, Config, HealthCheckConfig, UpstreamConfig}; + + fn make_pool(urls: &[&str]) -> Arc { + let cfg = Config { + listen: "127.0.0.1:8080".to_string(), + mode: BalanceMode::RoundRobin, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + upstream: urls + .iter() + .map(|u| UpstreamConfig { + url: u.to_string(), + weight: 1, + username: None, + password: None, + }) + .collect(), + }; + UpstreamPool::from_config(&cfg) + } + + #[test] + fn round_robin_cycles_through_online() { + let pool = make_pool(&["http://a:1", "http://b:2", "http://c:3"]); + // Collect 6 selections — should visit all 3 upstreams + let mut seen = std::collections::HashSet::new(); + for _ in 0..6 { + let (_, e) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); + seen.insert(e.config.url.clone()); + } + assert_eq!(seen.len(), 3); + } + + #[test] + fn offline_upstream_is_skipped() { + let pool = make_pool(&["http://a:1", "http://b:2"]); + // Mark first upstream offline + let (_, first) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); + first.mark_offline(); + + // All subsequent selections 
should be the other one + for _ in 0..10 { + let (_, e) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); + assert_ne!( + e.config.url, first.config.url, + "offline upstream was selected" + ); + } + } + + #[test] + fn no_online_upstream_returns_none() { + let pool = make_pool(&["http://a:1"]); + let (_, e) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); + e.mark_offline(); + assert!(pool.select(BalanceMode::RoundRobin, &[]).is_none()); + } + + #[test] + fn exclude_skips_specified_indices() { + let pool = make_pool(&["http://a:1", "http://b:2", "http://c:3"]); + let result = pool.select(BalanceMode::RoundRobin, &[0, 1]); + let (idx, _) = result.unwrap(); + assert_eq!(idx, 2); + } + + #[test] + fn mark_online_restores_offline_upstream() { + let pool = make_pool(&["http://a:1"]); + let (_, e) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); + e.mark_offline(); + assert!(pool.select(BalanceMode::RoundRobin, &[]).is_none()); + e.mark_online(); + assert!(pool.select(BalanceMode::RoundRobin, &[]).is_some()); + } + + #[test] + fn best_mode_picks_lowest_score() { + let pool = make_pool(&["http://a:1", "http://b:2"]); + let entries = pool.entries.lock(); + // Give "b" a very high latency to force "a" to win + entries[1] + .latency_ema_ms + .store(9999, std::sync::atomic::Ordering::Relaxed); + drop(entries); + let (idx, _) = pool.select(BalanceMode::Best, &[]).unwrap(); + assert_eq!(idx, 0, "should pick upstream with lower score"); + } + + #[test] + fn reload_preserves_state_for_existing_upstream() { + let pool = make_pool(&["http://a:1", "http://b:2"]); + // Mark "a" offline + let (_, a) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); + a.mark_offline(); + + // Reload with same upstreams + let new_cfg = Config { + listen: "127.0.0.1:8080".to_string(), + mode: BalanceMode::RoundRobin, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + upstream: vec![ + UpstreamConfig { + url: "http://a:1".to_string(), + weight: 1, + username: None, 
+ password: None, + }, + UpstreamConfig { + url: "http://b:2".to_string(), + weight: 1, + username: None, + password: None, + }, + UpstreamConfig { + url: "http://c:3".to_string(), + weight: 1, + username: None, + password: None, + }, + ], + }; + pool.reload(&new_cfg); + + // "a" should still be offline (state preserved) + let entries = pool.entries.lock(); + let a_entry = entries + .iter() + .find(|e| e.config.url == "http://a:1") + .unwrap(); + assert!( + !a_entry.is_online(), + "state should be preserved after reload" + ); + // "c" should be a new online entry + let c_entry = entries + .iter() + .find(|e| e.config.url == "http://c:3") + .unwrap(); + assert!(c_entry.is_online()); + } + + #[test] + fn weighted_round_robin_respects_weights() { + let cfg = Config { + listen: "127.0.0.1:8080".to_string(), + mode: BalanceMode::RoundRobin, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + upstream: vec![ + UpstreamConfig { + url: "http://a:1".to_string(), + weight: 1, + username: None, + password: None, + }, + UpstreamConfig { + url: "http://b:2".to_string(), + weight: 3, + username: None, + password: None, + }, + ], + }; + let pool = UpstreamPool::from_config(&cfg); + + let mut counts = std::collections::HashMap::new(); + for _ in 0..400 { + let (_, e) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); + *counts.entry(e.config.url.clone()).or_insert(0u32) += 1; + } + // "b" has 3x weight → should get ~3× more traffic + let a_count = counts["http://a:1"]; + let b_count = counts["http://b:2"]; + let ratio = b_count as f64 / a_count as f64; + assert!( + ratio > 2.5 && ratio < 3.5, + "expected ~3:1 ratio, got a={a_count} b={b_count} (ratio={ratio:.2})" + ); + } +} From 7e5c1687d199065941d2c234adf8ab8c58018e1c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Mar 2026 15:36:34 +0000 Subject: [PATCH 03/13] Address code review feedback: rename shadowed variable and remove unused 
parameters Co-authored-by: MejiroRina <70424266+MejiroRina@users.noreply.github.com> Agent-Logs-Url: https://github.com/Team-Haruki/http-proxy-lb/sessions/f3748e80-1418-4394-a40c-a54bceb41f43 --- src/proxy.rs | 4 ---- src/upstream.rs | 13 ++++++------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/proxy.rs b/src/proxy.rs index 4edd9fd..5ab5210 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -112,8 +112,6 @@ async fn dispatch( handle_http( client, &method, - &path, - version, buf, body_offset, req_content_length, @@ -237,8 +235,6 @@ async fn tunnel(client: &mut TcpStream, mut upstream: TcpStream) -> Result<()> { async fn handle_http( client: &mut TcpStream, method: &str, - _path: &str, - _version: u8, req_buf: &[u8], // raw bytes: headers + any already-buffered body bytes body_offset: usize, // where headers end within req_buf req_content_length: Option, diff --git a/src/upstream.rs b/src/upstream.rs index f176e59..389c0ef 100644 --- a/src/upstream.rs +++ b/src/upstream.rs @@ -213,16 +213,15 @@ impl UpstreamPool { let new_entries: Vec> = new_cfg .upstream .iter() - .map(|new_cfg| { + .map(|upstream_cfg| { // Reuse existing entry if URL matches (preserves state/stats). - if let Some(existing) = entries.iter().find(|e| e.config.url == new_cfg.url) { - // Update weight and auth in case they changed. - // State and stats are preserved. - debug!(url = %new_cfg.url, "reusing existing upstream entry"); + if let Some(existing) = entries.iter().find(|e| e.config.url == upstream_cfg.url) { + // State and stats are preserved; config reference is kept as-is. 
+ debug!(url = %upstream_cfg.url, "reusing existing upstream entry"); Arc::clone(existing) } else { - info!(url = %new_cfg.url, "adding new upstream"); - UpstreamEntry::new(new_cfg.clone()) + info!(url = %upstream_cfg.url, "adding new upstream"); + UpstreamEntry::new(upstream_cfg.clone()) } }) .collect(); From 4c344b70e1269382b3b15ae26593c555d0eb5486 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Mar 2026 15:49:20 +0000 Subject: [PATCH 04/13] Add priority balance mode with per-upstream priority and online failover Co-authored-by: MejiroRina <70424266+MejiroRina@users.noreply.github.com> Agent-Logs-Url: https://github.com/Team-Haruki/http-proxy-lb/sessions/b98dfa76-d6b3-429d-bd68-504f22a80536 --- README.md | 6 +- config.example.yaml | 7 +- src/config.rs | 16 +++++ src/upstream.rs | 157 +++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 182 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 7feecac..32fcc63 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ A high-availability HTTP relay proxy with upstream load balancing, health checki |---|---| | **Round-robin load balancing** | Weighted round-robin across all online upstream proxies | | **Best-selection mode** | Always routes to the upstream with the lowest latency + connection score | +| **Priority mode** | Routes to the highest-priority online upstream and automatically falls back when it is offline | | **Passive health detection** | Any upstream that fails to connect or respond is immediately marked offline and excluded from routing | | **Active health checking** | Background task periodically probes offline upstreams via TCP connect; marks them online once they recover | | **Automatic failover** | Failed requests are retried on a different upstream (up to `min(pool_size, 3)` retries) | @@ -41,7 +42,7 @@ Set `RUST_LOG=debug` for verbose logging. 
# Local address to listen on listen: "127.0.0.1:8080" -# Load-balancing mode: round_robin | best +# Load-balancing mode: round_robin | best | priority mode: round_robin # How often to re-read the config file (seconds). 0 = disabled. @@ -54,9 +55,11 @@ health_check: upstream: - url: "http://proxy1.example.com:8080" weight: 1 + priority: 10 - url: "http://proxy2.example.com:8080" weight: 2 + priority: 20 username: "user" password: "secret" ``` @@ -67,6 +70,7 @@ upstream: * **`best`** — Selects the online upstream with the lowest *score*, where: `score = latency_ema_ms + active_connections × 50` Latency is an exponential moving average (α = 0.25) of observed response times. +* **`priority`** — Selects the online upstream with the lowest `priority` value. When the current highest-priority upstream becomes offline, routing automatically switches to the next online priority level; recovered upstreams are reused after active health checks mark them online. ### Hot reload diff --git a/config.example.yaml b/config.example.yaml index 468dc50..0ef7b1d 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -4,10 +4,11 @@ # Local address and port to listen on listen: "127.0.0.1:8080" -# Load-balancing mode: round_robin | best +# Load-balancing mode: round_robin | best | priority # round_robin — weighted round-robin across all online upstreams # best — select the online upstream with the lowest combined score # (latency EMA + active-connection penalty) +# priority — select the online upstream with the lowest priority value mode: round_robin # Hot-reload: re-read config file every N seconds (0 = disabled) @@ -23,16 +24,20 @@ health_check: # Upstream proxy pool # url — http://host:port (only HTTP proxies supported) # weight — relative weight for round_robin (default: 1) +# priority — priority for priority mode (lower = higher priority, default: 100) # username — optional HTTP proxy Basic-Auth username # password — optional HTTP proxy Basic-Auth password upstream: - url: 
"http://proxy1.example.com:8080" weight: 1 + priority: 10 - url: "http://proxy2.example.com:8080" weight: 2 + priority: 20 username: "user" password: "secret" - url: "http://proxy3.example.com:3128" weight: 1 + priority: 30 diff --git a/src/config.rs b/src/config.rs index 923963f..6adb8b7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -46,6 +46,8 @@ pub enum BalanceMode { RoundRobin, /// Pick the online upstream with the best (lowest) combined score Best, + /// Pick the online upstream with the highest priority (lowest number) + Priority, } // --------------------------------------------------------------------------- @@ -89,6 +91,9 @@ pub struct UpstreamConfig { /// Relative weight for round_robin selection (default: 1) #[serde(default = "default_weight")] pub weight: u32, + /// Priority for priority-mode selection (lower value = higher priority) + #[serde(default = "default_priority")] + pub priority: u32, /// Optional HTTP proxy Basic-Auth username pub username: Option, /// Optional HTTP proxy Basic-Auth password @@ -99,6 +104,10 @@ fn default_weight() -> u32 { 1 } +fn default_priority() -> u32 { + 100 +} + impl UpstreamConfig { /// Returns the `Proxy-Authorization: Basic …` header value, if auth is configured. 
pub fn proxy_auth_header(&self) -> Option { @@ -160,6 +169,7 @@ mod tests { let cfg = UpstreamConfig { url: "http://proxy.example.com:8080".to_string(), weight: 1, + priority: 100, username: None, password: None, }; @@ -173,6 +183,7 @@ mod tests { let cfg = UpstreamConfig { url: "http://proxy.example.com".to_string(), weight: 1, + priority: 100, username: None, password: None, }; @@ -186,6 +197,7 @@ mod tests { let cfg = UpstreamConfig { url: "http://p:1".to_string(), weight: 1, + priority: 100, username: Some("alice".to_string()), password: Some("s3cr3t".to_string()), }; @@ -202,6 +214,7 @@ mod tests { let cfg = UpstreamConfig { url: "http://p:1".to_string(), weight: 1, + priority: 100, username: None, password: None, }; @@ -216,6 +229,7 @@ mode: best upstream: - url: "http://a:1" weight: 2 + priority: 10 - url: "http://b:2" "#; let cfg: Config = yaml_serde::from_str(yaml).unwrap(); @@ -223,7 +237,9 @@ upstream: assert_eq!(cfg.mode, BalanceMode::Best); assert_eq!(cfg.upstream.len(), 2); assert_eq!(cfg.upstream[0].weight, 2); + assert_eq!(cfg.upstream[0].priority, 10); assert_eq!(cfg.upstream[1].weight, 1); // default + assert_eq!(cfg.upstream[1].priority, 100); // default } #[test] diff --git a/src/upstream.rs b/src/upstream.rs index 389c0ef..89729c0 100644 --- a/src/upstream.rs +++ b/src/upstream.rs @@ -46,6 +46,17 @@ impl UpstreamEntry { }) } + fn from_existing(existing: &Self, config: UpstreamConfig) -> Arc { + Arc::new(Self { + config, + state: AtomicU8::new(existing.state.load(Ordering::Acquire)), + active_conns: AtomicI64::new(existing.active_conns.load(Ordering::Relaxed)), + latency_ema_ms: AtomicU64::new(existing.latency_ema_ms.load(Ordering::Relaxed)), + consec_failures: AtomicU32::new(existing.consec_failures.load(Ordering::Relaxed)), + last_state_change: Mutex::new(Instant::now()), + }) + } + // ---- state accessors --------------------------------------------------- pub fn is_online(&self) -> bool { @@ -157,6 +168,7 @@ impl UpstreamPool { match mode { 
BalanceMode::RoundRobin => self.select_rr(exclude), BalanceMode::Best => self.select_best(exclude), + BalanceMode::Priority => self.select_priority(exclude), } } @@ -187,6 +199,16 @@ impl UpstreamPool { .map(|(i, e)| (i, Arc::clone(e))) } + fn select_priority(&self, exclude: &[usize]) -> Option<(usize, Arc)> { + let entries = self.entries.lock(); + entries + .iter() + .enumerate() + .filter(|(i, e)| e.is_online() && !exclude.contains(i)) + .min_by_key(|(i, e)| (e.config.priority, *i)) + .map(|(i, e)| (i, Arc::clone(e))) + } + // ---- queries ----------------------------------------------------------- pub fn len(&self) -> usize { @@ -216,9 +238,9 @@ impl UpstreamPool { .map(|upstream_cfg| { // Reuse existing entry if URL matches (preserves state/stats). if let Some(existing) = entries.iter().find(|e| e.config.url == upstream_cfg.url) { - // State and stats are preserved; config reference is kept as-is. + // State/stats are preserved while config (weight/auth/priority) is refreshed. debug!(url = %upstream_cfg.url, "reusing existing upstream entry"); - Arc::clone(existing) + UpstreamEntry::from_existing(existing, upstream_cfg.clone()) } else { info!(url = %upstream_cfg.url, "adding new upstream"); UpstreamEntry::new(upstream_cfg.clone()) @@ -258,6 +280,7 @@ mod tests { .map(|u| UpstreamConfig { url: u.to_string(), weight: 1, + priority: 100, username: None, password: None, }) @@ -351,18 +374,21 @@ mod tests { UpstreamConfig { url: "http://a:1".to_string(), weight: 1, + priority: 100, username: None, password: None, }, UpstreamConfig { url: "http://b:2".to_string(), weight: 1, + priority: 100, username: None, password: None, }, UpstreamConfig { url: "http://c:3".to_string(), weight: 1, + priority: 100, username: None, password: None, }, @@ -399,12 +425,14 @@ mod tests { UpstreamConfig { url: "http://a:1".to_string(), weight: 1, + priority: 100, username: None, password: None, }, UpstreamConfig { url: "http://b:2".to_string(), weight: 3, + priority: 100, username: None, 
password: None, }, @@ -426,4 +454,129 @@ mod tests { "expected ~3:1 ratio, got a={a_count} b={b_count} (ratio={ratio:.2})" ); } + + #[test] + fn priority_mode_prefers_lowest_priority_value() { + let cfg = Config { + listen: "127.0.0.1:8080".to_string(), + mode: BalanceMode::Priority, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + upstream: vec![ + UpstreamConfig { + url: "http://a:1".to_string(), + weight: 1, + priority: 50, + username: None, + password: None, + }, + UpstreamConfig { + url: "http://b:2".to_string(), + weight: 1, + priority: 10, + username: None, + password: None, + }, + UpstreamConfig { + url: "http://c:3".to_string(), + weight: 1, + priority: 30, + username: None, + password: None, + }, + ], + }; + let pool = UpstreamPool::from_config(&cfg); + let (idx, entry) = pool.select(BalanceMode::Priority, &[]).unwrap(); + assert_eq!(idx, 1); + assert_eq!(entry.config.url, "http://b:2"); + } + + #[test] + fn priority_mode_fallbacks_to_next_online_priority() { + let cfg = Config { + listen: "127.0.0.1:8080".to_string(), + mode: BalanceMode::Priority, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + upstream: vec![ + UpstreamConfig { + url: "http://a:1".to_string(), + weight: 1, + priority: 1, + username: None, + password: None, + }, + UpstreamConfig { + url: "http://b:2".to_string(), + weight: 1, + priority: 2, + username: None, + password: None, + }, + ], + }; + let pool = UpstreamPool::from_config(&cfg); + let (_, first) = pool.select(BalanceMode::Priority, &[]).unwrap(); + assert_eq!(first.config.url, "http://a:1"); + first.mark_offline(); + + let (_, second) = pool.select(BalanceMode::Priority, &[]).unwrap(); + assert_eq!(second.config.url, "http://b:2"); + } + + #[test] + fn reload_updates_priority_for_existing_upstream() { + let cfg = Config { + listen: "127.0.0.1:8080".to_string(), + mode: BalanceMode::Priority, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + upstream: 
vec![ + UpstreamConfig { + url: "http://a:1".to_string(), + weight: 1, + priority: 100, + username: None, + password: None, + }, + UpstreamConfig { + url: "http://b:2".to_string(), + weight: 1, + priority: 200, + username: None, + password: None, + }, + ], + }; + let pool = UpstreamPool::from_config(&cfg); + let (_, first_before) = pool.select(BalanceMode::Priority, &[]).unwrap(); + assert_eq!(first_before.config.url, "http://a:1"); + + let reloaded = Config { + listen: "127.0.0.1:8080".to_string(), + mode: BalanceMode::Priority, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + upstream: vec![ + UpstreamConfig { + url: "http://a:1".to_string(), + weight: 1, + priority: 300, + username: None, + password: None, + }, + UpstreamConfig { + url: "http://b:2".to_string(), + weight: 1, + priority: 50, + username: None, + password: None, + }, + ], + }; + pool.reload(&reloaded); + let (_, first_after) = pool.select(BalanceMode::Priority, &[]).unwrap(); + assert_eq!(first_after.config.url, "http://b:2"); + } } From be4ff1e13d00057a6f55d316db920c03c0b3766d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Mar 2026 15:52:07 +0000 Subject: [PATCH 05/13] Address review nits for priority mode reload timestamp and test naming Co-authored-by: MejiroRina <70424266+MejiroRina@users.noreply.github.com> Agent-Logs-Url: https://github.com/Team-Haruki/http-proxy-lb/sessions/b98dfa76-d6b3-429d-bd68-504f22a80536 --- src/upstream.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/upstream.rs b/src/upstream.rs index 89729c0..e2bd7d8 100644 --- a/src/upstream.rs +++ b/src/upstream.rs @@ -53,7 +53,7 @@ impl UpstreamEntry { active_conns: AtomicI64::new(existing.active_conns.load(Ordering::Relaxed)), latency_ema_ms: AtomicU64::new(existing.latency_ema_ms.load(Ordering::Relaxed)), consec_failures: AtomicU32::new(existing.consec_failures.load(Ordering::Relaxed)), - last_state_change: 
Mutex::new(Instant::now()), + last_state_change: Mutex::new(*existing.last_state_change.lock()), }) } @@ -493,7 +493,7 @@ mod tests { } #[test] - fn priority_mode_fallbacks_to_next_online_priority() { + fn priority_mode_falls_back_to_next_online_priority() { let cfg = Config { listen: "127.0.0.1:8080".to_string(), mode: BalanceMode::Priority, From c1ddd514eda82323b4153239c9ddc74b940323c8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Mar 2026 16:03:00 +0000 Subject: [PATCH 06/13] Use captive generate_204 HTTP probe for active upstream health checks Co-authored-by: MejiroRina <70424266+MejiroRina@users.noreply.github.com> Agent-Logs-Url: https://github.com/Team-Haruki/http-proxy-lb/sessions/796359e3-233b-45ca-b09b-fd596935b478 --- README.md | 2 +- src/health.rs | 149 +++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 141 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 32fcc63..c0eb971 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ A high-availability HTTP relay proxy with upstream load balancing, health checki | **Best-selection mode** | Always routes to the upstream with the lowest latency + connection score | | **Priority mode** | Routes to the highest-priority online upstream and automatically falls back when it is offline | | **Passive health detection** | Any upstream that fails to connect or respond is immediately marked offline and excluded from routing | -| **Active health checking** | Background task periodically probes offline upstreams via TCP connect; marks them online once they recover | +| **Active health checking** | Background task probes offline upstreams using a captive-style HTTP probe (`generate_204`) through the proxy; marks them online on HTTP 204 | | **Automatic failover** | Failed requests are retried on a different upstream (up to `min(pool_size, 3)` retries) | | **HTTP/1.1 keep-alive** | Client connections are reused across 
requests | | **HTTPS tunneling** | `CONNECT` method is fully supported — traffic is tunneled through the upstream proxy | diff --git a/src/health.rs b/src/health.rs index cb7932a..523c176 100644 --- a/src/health.rs +++ b/src/health.rs @@ -1,17 +1,24 @@ use std::sync::Arc; use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::time::timeout; use tracing::{debug, info}; use crate::upstream::UpstreamPool; +const CAPTIVE_PROBE_URL: &str = "http://connectivitycheck.gstatic.com/generate_204"; +const CAPTIVE_PROBE_HOST: &str = "connectivitycheck.gstatic.com"; +// Enough to capture the status line and a few headers for probe validation. +const PROBE_RESPONSE_BUFFER_SIZE: usize = 256; + /// Runs the active health-checker in the background. /// /// Every `interval_secs` seconds it iterates over every **offline** upstream, -/// attempts a TCP-connect with `timeout_secs` deadline, and marks the upstream -/// **online** if the connection succeeds. +/// sends a captive-style HTTP probe request through the upstream proxy with +/// `timeout_secs` deadline, and marks the upstream **online** if it returns +/// HTTP 204. 
pub async fn run_health_checker(pool: Arc, interval_secs: u64, timeout_secs: u64) { let interval = Duration::from_secs(interval_secs); let probe_timeout = Duration::from_secs(timeout_secs); @@ -45,16 +52,16 @@ async fn probe_offline(pool: &Arc, probe_timeout: Duration) { return; } }; - match timeout(t, TcpStream::connect(&addr)).await { - Ok(Ok(_)) => { - info!(upstream = %entry.config.url, "health check OK — marking online"); + match probe_proxy_by_captive_http(&entry, &addr, t).await { + Ok(true) => { + info!(upstream = %entry.config.url, "health check OK (HTTP 204) — marking online"); entry.mark_online(); } - Ok(Err(e)) => { - debug!(upstream = %entry.config.url, error = %e, "health check failed"); + Ok(false) => { + debug!(upstream = %entry.config.url, "health check got non-204 response"); } - Err(_) => { - debug!(upstream = %entry.config.url, "health check timed out"); + Err(e) => { + debug!(upstream = %entry.config.url, error = %e, "health check failed"); } } }) @@ -65,3 +72,127 @@ async fn probe_offline(pool: &Arc, probe_timeout: Duration) { let _ = task.await; } } + +async fn probe_proxy_by_captive_http( + entry: &Arc, + addr: &str, + probe_timeout: Duration, +) -> Result { + let mut stream = timeout(probe_timeout, TcpStream::connect(addr)) + .await + .map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::TimedOut, + "connection to upstream proxy timed out", + ) + })??; + + let mut req = format!( + "GET {CAPTIVE_PROBE_URL} HTTP/1.1\r\nHost: {CAPTIVE_PROBE_HOST}\r\nConnection: close\r\nUser-Agent: http-proxy-lb-healthcheck/1.0\r\n" + ); + if let Some(auth) = entry.config.proxy_auth_header() { + req.push_str(&format!("Proxy-Authorization: {auth}\r\n")); + } + req.push_str("\r\n"); + + timeout(probe_timeout, stream.write_all(req.as_bytes())) + .await + .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "write timeout"))??; + + let mut buf = [0u8; PROBE_RESPONSE_BUFFER_SIZE]; + let n = timeout(probe_timeout, stream.read(&mut buf)) + .await + 
.map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, "read timeout"))??; + if n == 0 { + // Proxy closed connection before returning any HTTP response. + return Ok(false); + } + + let head = String::from_utf8_lossy(&buf[..n]); + let Some(status_line) = head.lines().next() else { + // Malformed HTTP response: no status line. + return Ok(false); + }; + Ok(status_line.starts_with("HTTP/1.1 204") || status_line.starts_with("HTTP/1.0 204")) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{BalanceMode, Config, HealthCheckConfig, UpstreamConfig}; + use crate::upstream::UpstreamPool; + use tokio::net::TcpListener; + + #[tokio::test] + async fn captive_204_marks_offline_upstream_online() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + + tokio::spawn(async move { + let (mut sock, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 1024]; + let n = sock.read(&mut buf).await.unwrap(); + let req = String::from_utf8_lossy(&buf[..n]); + assert!(req.contains("GET http://connectivitycheck.gstatic.com/generate_204 HTTP/1.1")); + let resp = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + sock.write_all(resp).await.unwrap(); + }); + + let cfg = Config { + listen: "127.0.0.1:8080".to_string(), + mode: BalanceMode::RoundRobin, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + upstream: vec![UpstreamConfig { + url: format!("http://127.0.0.1:{port}"), + weight: 1, + priority: 100, + username: None, + password: None, + }], + }; + let pool = UpstreamPool::from_config(&cfg); + let (_, entry) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); + entry.mark_offline(); + assert!(!entry.is_online()); + + probe_offline(&pool, Duration::from_secs(2)).await; + assert!(entry.is_online()); + } + + #[tokio::test] + async fn non_204_keeps_upstream_offline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let 
port = listener.local_addr().unwrap().port(); + + tokio::spawn(async move { + let (mut sock, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 1024]; + let _ = sock.read(&mut buf).await.unwrap(); + let resp = + b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + sock.write_all(resp).await.unwrap(); + }); + + let cfg = Config { + listen: "127.0.0.1:8080".to_string(), + mode: BalanceMode::RoundRobin, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + upstream: vec![UpstreamConfig { + url: format!("http://127.0.0.1:{port}"), + weight: 1, + priority: 100, + username: None, + password: None, + }], + }; + let pool = UpstreamPool::from_config(&cfg); + let (_, entry) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); + entry.mark_offline(); + assert!(!entry.is_online()); + + probe_offline(&pool, Duration::from_secs(2)).await; + assert!(!entry.is_online()); + } +} From 8c8b8ccbe544fdacc468359c187e78ca07ef5f93 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Mar 2026 16:17:34 +0000 Subject: [PATCH 07/13] Add domain blacklist/whitelist direct-proxy routing policy Co-authored-by: MejiroRina <70424266+MejiroRina@users.noreply.github.com> Agent-Logs-Url: https://github.com/Team-Haruki/http-proxy-lb/sessions/7142be7c-21ed-4f22-a713-fb8bfd12f8d5 --- README.md | 13 +++ config.example.yaml | 8 ++ src/config.rs | 56 +++++++++++ src/health.rs | 6 +- src/main.rs | 4 +- src/proxy.rs | 226 ++++++++++++++++++++++++++++++++++++++++++-- src/upstream.rs | 11 ++- 7 files changed, 313 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index c0eb971..f240238 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ A high-availability HTTP relay proxy with upstream load balancing, health checki | **HTTP/1.1 keep-alive** | Client connections are reused across requests | | **HTTPS tunneling** | `CONNECT` method is fully supported — traffic is tunneled 
through the upstream proxy | | **Upstream authentication** | Per-upstream HTTP Proxy Basic Auth (`Proxy-Authorization`) | +| **Domain blacklist/whitelist routing** | Supports selective direct/proxy routing by domain (`domain_policy`) | | **Hot config reload** | Config file is re-read every `reload_interval_secs`; new/removed upstreams are applied without restart | ## Installation @@ -52,6 +53,10 @@ health_check: interval_secs: 30 # probe interval for offline upstreams timeout_secs: 5 # TCP-connect timeout per probe +domain_policy: + mode: off # off | blacklist | whitelist + domains: [] # blacklist: direct; whitelist: proxy + upstream: - url: "http://proxy1.example.com:8080" weight: 1 @@ -76,6 +81,14 @@ upstream: Edit `config.yaml` while the proxy is running. Within `reload_interval_secs` seconds the proxy will pick up the new upstream list. Existing upstreams (matched by URL) keep their online/offline state and statistics. +### Domain policy + +`domain_policy` controls whether specific domains should use upstream proxy or direct connection: + +* **`off`** (default): all domains use upstream proxy. +* **`blacklist`**: domains listed in `domains` bypass upstream proxy and connect directly; other domains still use upstream proxy. +* **`whitelist`**: only domains listed in `domains` use upstream proxy; other domains connect directly. 
+ ## Usage with curl ```bash diff --git a/config.example.yaml b/config.example.yaml index 0ef7b1d..89a3a5b 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -21,6 +21,14 @@ health_check: # TCP-connect timeout when probing (seconds) timeout_secs: 5 +# Optional domain routing policy: +# off — all domains use upstream proxy (default) +# blacklist — listed domains go direct (bypass upstream proxy) +# whitelist — listed domains use upstream proxy; others go direct +domain_policy: + mode: off + domains: [] + # Upstream proxy pool # url — http://host:port (only HTTP proxies supported) # weight — relative weight for round_robin (default: 1) diff --git a/src/config.rs b/src/config.rs index 6adb8b7..42edd38 100644 --- a/src/config.rs +++ b/src/config.rs @@ -25,6 +25,10 @@ pub struct Config { #[serde(default)] pub health_check: HealthCheckConfig, + /// Domain-based direct/proxy routing policy + #[serde(default)] + pub domain_policy: DomainPolicyConfig, + /// Upstream proxy list #[serde(default)] pub upstream: Vec, @@ -50,6 +54,39 @@ pub enum BalanceMode { Priority, } +// --------------------------------------------------------------------------- +// Domain policy config +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum DomainPolicyMode { + /// Always use upstream proxy (default) + #[default] + Off, + /// Listed domains go direct, others use upstream proxy + Blacklist, + /// Listed domains use upstream proxy, others go direct + Whitelist, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DomainPolicyConfig { + #[serde(default)] + pub mode: DomainPolicyMode, + #[serde(default)] + pub domains: Vec, +} + +impl Default for DomainPolicyConfig { + fn default() -> Self { + Self { + mode: DomainPolicyMode::Off, + domains: Vec::new(), + } + } +} + // 
--------------------------------------------------------------------------- // Health-check config // --------------------------------------------------------------------------- @@ -253,5 +290,24 @@ upstream: [] assert_eq!(cfg.reload_interval_secs, 60); assert_eq!(cfg.health_check.interval_secs, 30); assert_eq!(cfg.health_check.timeout_secs, 5); + assert_eq!(cfg.domain_policy.mode, DomainPolicyMode::Off); + assert!(cfg.domain_policy.domains.is_empty()); + } + + #[test] + fn test_parse_yaml_domain_policy() { + let yaml = r#" +listen: "0.0.0.0:8080" +domain_policy: + mode: blacklist + domains: + - "example.com" + - "internal.local" +upstream: [] +"#; + let cfg: Config = yaml_serde::from_str(yaml).unwrap(); + assert_eq!(cfg.domain_policy.mode, DomainPolicyMode::Blacklist); + assert_eq!(cfg.domain_policy.domains.len(), 2); + assert_eq!(cfg.domain_policy.domains[0], "example.com"); } } diff --git a/src/health.rs b/src/health.rs index 523c176..1cda8e7 100644 --- a/src/health.rs +++ b/src/health.rs @@ -119,7 +119,9 @@ async fn probe_proxy_by_captive_http( #[cfg(test)] mod tests { use super::*; - use crate::config::{BalanceMode, Config, HealthCheckConfig, UpstreamConfig}; + use crate::config::{ + BalanceMode, Config, DomainPolicyConfig, HealthCheckConfig, UpstreamConfig, + }; use crate::upstream::UpstreamPool; use tokio::net::TcpListener; @@ -143,6 +145,7 @@ mod tests { mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), upstream: vec![UpstreamConfig { url: format!("http://127.0.0.1:{port}"), weight: 1, @@ -179,6 +182,7 @@ mod tests { mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), upstream: vec![UpstreamConfig { url: format!("http://127.0.0.1:{port}"), weight: 1, diff --git a/src/main.rs b/src/main.rs index 79a239f..2621fab 100644 --- a/src/main.rs +++ b/src/main.rs @@ -56,6 
+56,7 @@ async fn main() -> Result<()> { let reload_interval = cfg.reload_interval_secs; let hc_interval = cfg.health_check.interval_secs; let hc_timeout = cfg.health_check.timeout_secs; + let domain_policy = Arc::new(cfg.domain_policy.clone()); // Build upstream pool let pool = UpstreamPool::from_config(&cfg); @@ -89,8 +90,9 @@ async fn main() -> Result<()> { Ok((stream, peer)) => { debug_assert_ne!(peer.port(), 0); let pool = Arc::clone(&pool); + let domain_policy = Arc::clone(&domain_policy); tokio::spawn(async move { - proxy::handle_client(stream, pool, mode).await; + proxy::handle_client(stream, pool, mode, domain_policy).await; }); } Err(e) => { diff --git a/src/proxy.rs b/src/proxy.rs index 5ab5210..07da367 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -15,7 +15,7 @@ use tokio::net::TcpStream; use tokio::time::{timeout, Duration}; use tracing::{debug, warn}; -use crate::config::BalanceMode; +use crate::config::{BalanceMode, DomainPolicyConfig, DomainPolicyMode}; use crate::upstream::UpstreamPool; // --------------------------------------------------------------------------- @@ -33,7 +33,12 @@ const MAX_RETRIES: usize = 3; /// Accept a single client TCP connection and serve HTTP proxy requests on it. /// Supports keep-alive: loops until the client closes or sends `Connection: close`. 
-pub async fn handle_client(mut client: TcpStream, pool: Arc, mode: BalanceMode) { +pub async fn handle_client( + mut client: TcpStream, + pool: Arc, + mode: BalanceMode, + domain_policy: Arc, +) { let peer = client .peer_addr() .map(|a| a.to_string()) @@ -43,7 +48,7 @@ pub async fn handle_client(mut client: TcpStream, pool: Arc, mode: loop { match read_headers(&mut client).await { Ok(Some(buf)) => { - match dispatch(&mut client, &buf, &pool, mode).await { + match dispatch(&mut client, &buf, &pool, mode, &domain_policy).await { Ok(true) => continue, // keep-alive: read next request Ok(false) => break, Err(e) => { @@ -73,6 +78,7 @@ async fn dispatch( buf: &[u8], pool: &Arc, mode: BalanceMode, + domain_policy: &Arc, ) -> Result { // --- parse request line + headers --- let mut raw_headers = [httparse::EMPTY_HEADER; 96]; @@ -103,15 +109,19 @@ async fn dispatch( // --- CONNECT (HTTPS tunnel) --- if method.eq_ignore_ascii_case("CONNECT") { let (host, port) = parse_connect_target(&path)?; - return handle_connect(client, &host, port, pool, mode) + let use_proxy = should_use_proxy(&host, domain_policy); + return handle_connect(client, &host, port, pool, mode, use_proxy) .await .map(|_| false); } // --- Plain HTTP --- + let (target_host, target_port) = parse_http_target(&path, headers)?; + let use_proxy = should_use_proxy(&target_host, domain_policy); handle_http( client, &method, + &path, buf, body_offset, req_content_length, @@ -119,6 +129,8 @@ async fn dispatch( client_keep_alive, pool, mode, + use_proxy, + (&target_host, target_port), ) .await } @@ -133,7 +145,17 @@ async fn handle_connect( port: u16, pool: &Arc, mode: BalanceMode, + use_proxy: bool, ) -> Result<()> { + if !use_proxy { + let addr = format!("{host}:{port}"); + let up_stream = connect_target(&addr).await?; + client + .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n") + .await?; + return tunnel(client, up_stream).await; + } + let mut tried: Vec = Vec::new(); let max_retries = 
pool.len().min(MAX_RETRIES); @@ -235,6 +257,7 @@ async fn tunnel(client: &mut TcpStream, mut upstream: TcpStream) -> Result<()> { async fn handle_http( client: &mut TcpStream, method: &str, + path: &str, req_buf: &[u8], // raw bytes: headers + any already-buffered body bytes body_offset: usize, // where headers end within req_buf req_content_length: Option, @@ -242,7 +265,48 @@ async fn handle_http( client_keep_alive: bool, pool: &Arc, mode: BalanceMode, + use_proxy: bool, + direct_target: (&str, u16), ) -> Result { + if !use_proxy { + let addr = format!("{}:{}", direct_target.0, direct_target.1); + let mut direct_stream = match connect_target(&addr).await { + Ok(s) => s, + Err(_) => { + let _ = write_status(client, 502, "Bad Gateway").await; + return Ok(false); + } + }; + let no_proxy_auth = None; + let direct_original_path = Some(path); + let fwd_headers = + rewrite_request_headers(req_buf, body_offset, no_proxy_auth, direct_original_path); + direct_stream.write_all(&fwd_headers).await?; + + let already_in_buf = req_buf.len().saturating_sub(body_offset); + if already_in_buf > 0 { + direct_stream.write_all(&req_buf[body_offset..]).await?; + } + if forward_body( + client, + &mut direct_stream, + req_content_length, + req_is_chunked, + already_in_buf as u64, + ) + .await + .is_err() + { + return Ok(false); + } + + let upstream_keep_alive = match forward_response(&mut direct_stream, client, method).await { + Ok(v) => v, + Err(_) => return Ok(false), + }; + return Ok(client_keep_alive && upstream_keep_alive); + } + let mut tried: Vec = Vec::new(); let max_retries = pool.len().min(MAX_RETRIES); @@ -275,8 +339,12 @@ async fn handle_http( }; // --- Build and send request headers --- - let fwd_headers = - rewrite_request_headers(req_buf, body_offset, upstream.config.proxy_auth_header()); + let fwd_headers = rewrite_request_headers( + req_buf, + body_offset, + upstream.config.proxy_auth_header(), + None, + ); if let Err(e) = up_stream.write_all(&fwd_headers).await { 
warn!(upstream = %upstream.config.url, error = %e, "HTTP: write headers failed"); upstream.mark_offline(); @@ -527,16 +595,26 @@ async fn read_headers(stream: &mut TcpStream) -> Result>> { /// * Removes `Proxy-Authorization` from the client to avoid leaking /// the client's upstream credentials to the next hop. /// * Injects our upstream's `Proxy-Authorization` if required. -fn rewrite_request_headers(raw: &[u8], body_offset: usize, proxy_auth: Option) -> Vec { +fn rewrite_request_headers( + raw: &[u8], + body_offset: usize, + proxy_auth: Option, + original_path: Option<&str>, +) -> Vec { let header_bytes = &raw[..body_offset]; let header_str = String::from_utf8_lossy(header_bytes); let mut out = String::with_capacity(header_bytes.len() + 64); let mut first_line = true; + let is_direct_mode = original_path.is_some(); for line in header_str.split("\r\n") { if first_line { - out.push_str(line); + if let Some(path) = original_path { + out.push_str(&rewrite_request_line_for_direct(line, path)); + } else { + out.push_str(line); + } out.push_str("\r\n"); first_line = false; continue; @@ -556,6 +634,9 @@ fn rewrite_request_headers(raw: &[u8], body_offset: usize, proxy_auth: Option Result { // --------------------------------------------------------------------------- async fn connect_upstream(addr: &str) -> Result { + connect_target(addr).await +} + +async fn connect_target(addr: &str) -> Result { timeout(CONNECT_TIMEOUT, TcpStream::connect(addr)) .await .map_err(|_| anyhow!("connection to {addr} timed out"))? 
@@ -622,6 +707,95 @@ fn get_header<'a>(headers: &'a [httparse::Header<'a>], name: &str) -> Option<&'a }) } +fn should_use_proxy(host: &str, policy: &DomainPolicyConfig) -> bool { + let host_lc = host.to_ascii_lowercase(); + let matched = policy.domains.iter().any(|d| domain_matches(&host_lc, d)); + match policy.mode { + DomainPolicyMode::Off => true, + DomainPolicyMode::Blacklist => !matched, + DomainPolicyMode::Whitelist => matched, + } +} + +fn domain_matches(host_lc: &str, domain: &str) -> bool { + let d = domain.trim().to_ascii_lowercase(); + host_lc == d || host_lc.ends_with(&format!(".{d}")) +} + +fn parse_http_target(path: &str, headers: &[httparse::Header<'_>]) -> Result<(String, u16)> { + if let Some(rest) = path + .strip_prefix("http://") + .or_else(|| path.strip_prefix("https://")) + { + let authority = rest.split('/').next().unwrap_or(rest); + return parse_authority_host_port(authority, 80); + } + let host = get_header(headers, "host").ok_or_else(|| anyhow!("missing Host header"))?; + parse_authority_host_port(host, 80) +} + +fn parse_authority_host_port(authority: &str, default_port: u16) -> Result<(String, u16)> { + let authority = authority.trim(); + let authority = authority.rsplit('@').next().unwrap_or(authority); + + if let Some(rest) = authority.strip_prefix('[') { + let end = rest + .find(']') + .ok_or_else(|| anyhow!("invalid bracketed authority: {authority}"))?; + let host = &rest[..end]; + if host.is_empty() { + bail!("invalid authority: {authority}"); + } + + let port = match &rest[end + 1..] { + "" => default_port, + suffix if suffix.starts_with(':') => suffix[1..].parse::().unwrap_or(default_port), + _ => bail!("invalid bracketed authority: {authority}"), + }; + return Ok((host.to_string(), port)); + } + + let colon_count = authority.matches(':').count(); + if colon_count > 1 { + // Unbracketed IPv6 literal; treat entire authority as host with default port. 
+ return Ok((authority.to_string(), default_port)); + } + + let mut parts = authority.splitn(2, ':'); + let host = parts + .next() + .filter(|h| !h.is_empty()) + .ok_or_else(|| anyhow!("invalid authority: {authority}"))? + .to_string(); + let port = match parts.next() { + Some("") | None => default_port, + Some(p) => p + .parse::() + .map_err(|_| anyhow!("invalid port in authority: {authority}"))?, + }; + Ok((host, port)) +} + +fn rewrite_request_line_for_direct(line: &str, original_path: &str) -> String { + let mut p = line.split_whitespace(); + let method = p.next().unwrap_or("GET"); + let fallback_target = p.next().unwrap_or(original_path); + let version = p.next().unwrap_or("HTTP/1.1"); + let target = absolute_to_origin_form(original_path).unwrap_or(fallback_target); + format!("{method} {target} {version}") +} + +fn absolute_to_origin_form(path: &str) -> Option<&str> { + let rest = path + .strip_prefix("http://") + .or_else(|| path.strip_prefix("https://"))?; + let slash = rest.find('/').unwrap_or(rest.len()); + if slash == rest.len() { + return Some("/"); + } + Some(&rest[slash..]) +} + fn client_wants_keep_alive(headers: &[httparse::Header<'_>], version: u8) -> bool { match get_header(headers, "connection").map(|v| v.to_ascii_lowercase()) { Some(ref v) if v.contains("close") => false, @@ -643,3 +817,39 @@ async fn write_status(stream: &mut TcpStream, code: u16, msg: &str) -> Result<() stream.write_all(resp.as_bytes()).await?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn whitelist_only_uses_proxy_for_listed_domains() { + let p = DomainPolicyConfig { + mode: DomainPolicyMode::Whitelist, + domains: vec!["example.com".to_string()], + }; + assert!(should_use_proxy("api.example.com", &p)); + assert!(!should_use_proxy("google.com", &p)); + } + + #[test] + fn blacklist_skips_proxy_for_listed_domains() { + let p = DomainPolicyConfig { + mode: DomainPolicyMode::Blacklist, + domains: vec!["example.com".to_string()], + }; + 
assert!(!should_use_proxy("example.com", &p)); + assert!(should_use_proxy("google.com", &p)); + } + + #[test] + fn absolute_uri_is_rewritten_to_origin_form() { + assert_eq!( + rewrite_request_line_for_direct( + "GET http://example.com/a/b?q=1 HTTP/1.1", + "http://example.com/a/b?q=1" + ), + "GET /a/b?q=1 HTTP/1.1" + ); + } +} diff --git a/src/upstream.rs b/src/upstream.rs index e2bd7d8..77cc628 100644 --- a/src/upstream.rs +++ b/src/upstream.rs @@ -267,7 +267,9 @@ impl UpstreamPool { #[cfg(test)] mod tests { use super::*; - use crate::config::{BalanceMode, Config, HealthCheckConfig, UpstreamConfig}; + use crate::config::{ + BalanceMode, Config, DomainPolicyConfig, HealthCheckConfig, UpstreamConfig, + }; fn make_pool(urls: &[&str]) -> Arc { let cfg = Config { @@ -275,6 +277,7 @@ mod tests { mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), upstream: urls .iter() .map(|u| UpstreamConfig { @@ -370,6 +373,7 @@ mod tests { mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), @@ -421,6 +425,7 @@ mod tests { mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), @@ -462,6 +467,7 @@ mod tests { mode: BalanceMode::Priority, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), @@ -499,6 +505,7 @@ mod tests { mode: BalanceMode::Priority, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), @@ -532,6 +539,7 @@ mod tests 
{ mode: BalanceMode::Priority, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), @@ -558,6 +566,7 @@ mod tests { mode: BalanceMode::Priority, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), From cabf28a2a93d45b5244184106170de0a6e40a8b6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Mar 2026 17:00:21 +0000 Subject: [PATCH 08/13] Support domain/suffix expressions and add agent instruction docs Co-authored-by: MejiroRina <70424266+MejiroRina@users.noreply.github.com> Agent-Logs-Url: https://github.com/Team-Haruki/http-proxy-lb/sessions/4c8ca4fb-6fd8-475f-b504-6d64f5f3361f --- AGENTS.md | 25 +++++++++++++++ README.md | 9 +++++- config.example.yaml | 6 ++++ copilot-Instructions.md | 27 ++++++++++++++++ src/proxy.rs | 68 +++++++++++++++++++++++++++++++++++++++-- 5 files changed, 131 insertions(+), 4 deletions(-) create mode 100644 AGENTS.md create mode 100644 copilot-Instructions.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..a89c309 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,25 @@ +# AGENTS.md + +## Purpose + +This repository contains `http-proxy-lb`, a Rust HTTP relay proxy with upstream load balancing, health checks, hot reload, and domain policy routing. + +## Scope + +- Keep code changes minimal and focused. +- Prefer targeted tests first, then full validation. +- Preserve existing behavior unless the task explicitly requires a change. 
+ +## Local validation + +- Run tests: `cargo test` +- Run lint checks: `cargo clippy -- -D warnings` +- Format code: `cargo fmt` + +## Code structure + +- `src/config.rs`: configuration model and YAML loading +- `src/upstream.rs`: upstream entry state + pool selection/reload +- `src/health.rs`: active health checking logic +- `src/proxy.rs`: CONNECT and HTTP forwarding logic +- `src/main.rs`: startup, accept loop, background tasks diff --git a/README.md b/README.md index f240238..97541c7 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ health_check: domain_policy: mode: off # off | blacklist | whitelist - domains: [] # blacklist: direct; whitelist: proxy + domains: [] # supports domain:, suffix:, *., . shorthand upstream: - url: "http://proxy1.example.com:8080" @@ -89,6 +89,13 @@ Edit `config.yaml` while the proxy is running. Within `reload_interval_secs` se * **`blacklist`**: domains listed in `domains` bypass upstream proxy and connect directly; other domains still use upstream proxy. * **`whitelist`**: only domains listed in `domains` use upstream proxy; other domains connect directly. 
+Expression formats in `domains`: + +* `domain:example.com` — exact domain match +* `suffix:example.com` — suffix match (`example.com` and `*.example.com`) +* `*.example.com` / `.example.com` — suffix shorthand +* `example.com` — backward-compatible exact-or-suffix match + ## Usage with curl ```bash diff --git a/config.example.yaml b/config.example.yaml index 89a3a5b..e30d17c 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -27,6 +27,12 @@ health_check: # whitelist — listed domains use upstream proxy; others go direct domain_policy: mode: off + # expression examples: + # - "domain:example.com" # exact + # - "suffix:example.com" # suffix + # - "*.example.com" # suffix shorthand + # - ".example.com" # suffix shorthand + # - "example.com" # backward-compatible exact-or-suffix domains: [] # Upstream proxy pool diff --git a/copilot-Instructions.md b/copilot-Instructions.md new file mode 100644 index 0000000..60e6e26 --- /dev/null +++ b/copilot-Instructions.md @@ -0,0 +1,27 @@ +# copilot-Instructions.md + +## Project coding notes + +- Language: Rust (edition 2021) +- Runtime: Tokio +- Proxy protocol handling is in `src/proxy.rs` +- Configuration is YAML via `yaml_serde` + +## Development expectations + +1. Make surgical changes that directly address the request. +2. Add/adjust tests for changed behavior. +3. Keep docs aligned with user-facing configuration changes. +4. 
Validate with: + - `cargo test` + - `cargo clippy -- -D warnings` + +## Domain policy expressions + +`domain_policy.domains` currently supports: + +- `domain:example.com` (exact match) +- `suffix:example.com` (domain suffix match) +- `*.example.com` (suffix shorthand) +- `.example.com` (suffix shorthand) +- `example.com` (backward-compatible exact-or-suffix behavior) diff --git a/src/proxy.rs b/src/proxy.rs index 07da367..cc76f3b 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -717,9 +717,50 @@ fn should_use_proxy(host: &str, policy: &DomainPolicyConfig) -> bool { } } -fn domain_matches(host_lc: &str, domain: &str) -> bool { - let d = domain.trim().to_ascii_lowercase(); - host_lc == d || host_lc.ends_with(&format!(".{d}")) +fn domain_matches(host_lc: &str, domain_expr: &str) -> bool { + let expr = domain_expr.trim().to_ascii_lowercase(); + if expr.is_empty() { + return false; + } + + if let Some(domain) = expr.strip_prefix("domain:") { + if domain.is_empty() { + return false; + } + return host_lc == domain; + } + if let Some(suffix) = expr.strip_prefix("suffix:") { + if suffix.is_empty() { + return false; + } + return host_matches_suffix(host_lc, suffix); + } + if let Some(suffix) = expr.strip_prefix("*.") { + if suffix.is_empty() { + return false; + } + return host_matches_suffix(host_lc, suffix); + } + if let Some(suffix) = expr.strip_prefix('.') { + if suffix.is_empty() { + return false; + } + return host_matches_suffix(host_lc, suffix); + } + + // Backward-compatible default behavior for plain values: + // exact domain OR subdomain suffix. 
+ host_lc == expr || host_matches_suffix(host_lc, &expr) +} + +fn host_matches_suffix(host_lc: &str, suffix: &str) -> bool { + if suffix.is_empty() { + return false; + } + host_lc == suffix + || host_lc + .strip_suffix(suffix) + .is_some_and(|prefix| prefix.ends_with('.')) } fn parse_http_target(path: &str, headers: &[httparse::Header<'_>]) -> Result<(String, u16)> { @@ -852,4 +893,25 @@ mod tests { "GET /a/b?q=1 HTTP/1.1" ); } + + #[test] + fn domain_expression_supports_exact_and_suffix_variants() { + assert!(domain_matches("api.example.com", "suffix:example.com")); + assert!(domain_matches("api.example.com", "*.example.com")); + assert!(domain_matches("api.example.com", ".example.com")); + assert!(domain_matches("example.com", "domain:example.com")); + assert!(!domain_matches("api.example.com", "domain:example.com")); + assert!(domain_matches("example.com", "example.com")); + assert!(domain_matches("api.example.com", "example.com")); + assert!(!domain_matches("example.com", "")); + assert!(!domain_matches("example.com", "domain:")); + assert!(!domain_matches("example.com", "suffix:")); + assert!(!domain_matches("example.com", "*.")); + assert!(!domain_matches("example.com", ".")); + } + + #[test] + fn empty_suffix_does_not_match() { + assert!(!host_matches_suffix("example.com", "")); + } } From 198478899233c8e46ffbdf0f6dec37b18aee71ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=9F=E9=9B=B2=E5=B8=8C=E5=87=AA?= <70424266+MejiroRina@users.noreply.github.com> Date: Sun, 22 Mar 2026 07:06:25 +0800 Subject: [PATCH 09/13] [Feat] More features --- .../copilot-Instructions.md | 0 Dockerfile | 50 ++ README.md | 112 +++++ config.example.yaml | 15 + docker-compose.yml | 37 ++ http-proxy-lb.service | 34 ++ src/admin.rs | 438 ++++++++++++++++++ src/config.rs | 80 ++++ src/health.rs | 29 +- src/main.rs | 249 +++++++++- src/proxy.rs | 69 ++- src/upstream.rs | 55 ++- 12 files changed, 1139 insertions(+), 29 deletions(-) rename copilot-Instructions.md => 
.github/copilot-Instructions.md (100%) create mode 100644 Dockerfile create mode 100644 docker-compose.yml create mode 100644 http-proxy-lb.service create mode 100644 src/admin.rs diff --git a/copilot-Instructions.md b/.github/copilot-Instructions.md similarity index 100% rename from copilot-Instructions.md rename to .github/copilot-Instructions.md diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..7daf6a2 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,50 @@ +# Build stage +FROM rust:1.75-slim-bookworm AS builder + +WORKDIR /app + +# Install build dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +# Copy source code +COPY Cargo.toml Cargo.lock ./ +COPY src ./src + +# Build release binary +RUN cargo build --release + +# Runtime stage +FROM debian:bookworm-slim + +# Install runtime dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Create non-root user +RUN useradd -r -s /bin/false proxy + +# Copy binary from builder +COPY --from=builder /app/target/release/http-proxy-lb /usr/local/bin/ + +# Create config directory +RUN mkdir -p /etc/http-proxy-lb && chown proxy:proxy /etc/http-proxy-lb + +# Switch to non-root user +USER proxy + +# Default config location +ENV CONFIG_PATH=/etc/http-proxy-lb/config.yaml + +# Expose proxy port and admin port +EXPOSE 8080 9090 + +# Health check using admin endpoint +HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:9090/health || exit 1 + +# Run the proxy +ENTRYPOINT ["http-proxy-lb"] +CMD ["--config", "/etc/http-proxy-lb/config.yaml"] diff --git a/README.md b/README.md index 97541c7..b7ae14f 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,11 @@ A high-availability HTTP relay proxy with upstream load balancing, health checki | **Upstream authentication** | Per-upstream HTTP Proxy Basic Auth 
(`Proxy-Authorization`) | | **Domain blacklist/whitelist routing** | Supports selective direct/proxy routing by domain (`domain_policy`) | | **Hot config reload** | Config file is re-read every `reload_interval_secs`; new/removed upstreams are applied without restart | +| **Prometheus metrics** | `/metrics` endpoint for monitoring with Prometheus | +| **Admin API** | `/status` JSON endpoint and `/health` health check | +| **Graceful shutdown** | Handles SIGTERM/SIGINT, waits for active connections to complete | +| **Connection limiting** | Optional max concurrent connections and request timeout | +| **Access logging** | Optional structured access logs for each request | ## Installation @@ -27,6 +32,28 @@ cargo build --release # binary is at target/release/http-proxy-lb ``` +### Docker + +```bash +# Build image +docker build -t http-proxy-lb . + +# Run with config file +docker run -d \ + -p 8080:8080 \ + -p 9090:9090 \ + -v $(pwd)/config.yaml:/etc/http-proxy-lb/config.yaml:ro \ + http-proxy-lb +``` + +Or use Docker Compose: + +```bash +cp config.example.yaml config.yaml +# edit config.yaml +docker compose up -d +``` + ## Quick start ```bash @@ -37,22 +64,40 @@ cp config.example.yaml config.yaml Set `RUST_LOG=debug` for verbose logging. +### Validate configuration + +```bash +./target/release/http-proxy-lb --config config.yaml --check +``` + ## Configuration ```yaml # Local address to listen on listen: "127.0.0.1:8080" +# Admin server for /metrics and /status endpoints (optional) +admin_listen: "127.0.0.1:9090" + # Load-balancing mode: round_robin | best | priority mode: round_robin # How often to re-read the config file (seconds). 0 = disabled. 
reload_interval_secs: 60 +# Enable access logging +access_log: false + health_check: interval_secs: 30 # probe interval for offline upstreams timeout_secs: 5 # TCP-connect timeout per probe +# Resource limits (0 = unlimited) +limits: + max_connections: 0 # max concurrent connections + request_timeout_secs: 0 # request timeout + shutdown_timeout_secs: 30 # graceful shutdown timeout + domain_policy: mode: off # off | blacklist | whitelist domains: [] # supports domain:, suffix:, *., . shorthand @@ -96,6 +141,51 @@ Expression formats in `domains`: * `*.example.com` / `.example.com` — suffix shorthand * `example.com` — backward-compatible exact-or-suffix match +## Monitoring + +When `admin_listen` is configured, the following endpoints are available: + +### GET /metrics + +Prometheus-compatible metrics: + +``` +http_proxy_lb_uptime_seconds 3600 +http_proxy_lb_requests_total 150000 +http_proxy_lb_requests_success 149500 +http_proxy_lb_requests_failed 500 +http_proxy_lb_active_connections 42 +http_proxy_lb_upstream_online{url="http://proxy1:8080"} 1 +http_proxy_lb_upstream_latency_ms{url="http://proxy1:8080"} 23 +``` + +### GET /status + +JSON status of all upstreams: + +```json +{ + "uptime_seconds": 3600, + "requests_total": 150000, + "requests_success": 149500, + "active_connections": 42, + "upstreams": [ + { + "url": "http://proxy1:8080", + "online": true, + "active_connections": 20, + "latency_ms": 23, + "weight": 1, + "priority": 10 + } + ] +} +``` + +### GET /health + +Simple health check endpoint (returns `{"status":"ok"}`). + ## Usage with curl ```bash @@ -124,6 +214,28 @@ Client ──► [http-proxy-lb listener] Each client connection is handled in its own Tokio task. Upstream connections are established fresh per request (no upstream connection pooling). 
+## Deployment + +### Systemd + +Copy the binary and service file: + +```bash +sudo cp target/release/http-proxy-lb /usr/local/bin/ +sudo cp http-proxy-lb.service /etc/systemd/system/ +sudo mkdir -p /etc/http-proxy-lb +sudo cp config.yaml /etc/http-proxy-lb/ +sudo useradd -r -s /bin/false proxy +sudo systemctl daemon-reload +sudo systemctl enable --now http-proxy-lb +``` + +### Docker Compose with Prometheus + +```bash +docker compose --profile monitoring up -d +``` + ## License MIT diff --git a/config.example.yaml b/config.example.yaml index e30d17c..8b6876b 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -4,6 +4,9 @@ # Local address and port to listen on listen: "127.0.0.1:8080" +# Optional: Admin server address for /metrics (Prometheus) and /status (JSON) endpoints +# admin_listen: "127.0.0.1:9090" + # Load-balancing mode: round_robin | best | priority # round_robin — weighted round-robin across all online upstreams # best — select the online upstream with the lowest combined score @@ -14,6 +17,9 @@ mode: round_robin # Hot-reload: re-read config file every N seconds (0 = disabled) reload_interval_secs: 60 +# Enable access logging (logs each request with method, target, status, latency) +access_log: false + # Active health-check settings health_check: # How often to probe offline upstreams (seconds) @@ -21,6 +27,15 @@ health_check: # TCP-connect timeout when probing (seconds) timeout_secs: 5 +# Resource limits (all optional, 0 = unlimited) +limits: + # Maximum concurrent client connections (0 = unlimited) + max_connections: 0 + # Request timeout in seconds (0 = unlimited) + request_timeout_secs: 0 + # Graceful shutdown timeout in seconds + shutdown_timeout_secs: 30 + # Optional domain routing policy: # off — all domains use upstream proxy (default) # blacklist — listed domains go direct (bypass upstream proxy) diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..0d7c2c0 --- /dev/null +++ b/docker-compose.yml @@ 
-0,0 +1,37 @@ +services: + http-proxy-lb: + build: . + container_name: http-proxy-lb + restart: unless-stopped + ports: + - "8080:8080" # Proxy port + - "9090:9090" # Admin/metrics port + volumes: + - ./config.yaml:/etc/http-proxy-lb/config.yaml:ro + environment: + - RUST_LOG=info + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9090/health"] + interval: 30s + timeout: 5s + retries: 3 + start_period: 5s + + # Optional: Prometheus for metrics collection + prometheus: + image: prom/prometheus:latest + container_name: prometheus + restart: unless-stopped + ports: + - "9091:9090" + volumes: + - ./prometheus.yml:/etc/prometheus/prometheus.yml:ro + - prometheus_data:/prometheus + command: + - '--config.file=/etc/prometheus/prometheus.yml' + - '--storage.tsdb.path=/prometheus' + profiles: + - monitoring + +volumes: + prometheus_data: diff --git a/http-proxy-lb.service b/http-proxy-lb.service new file mode 100644 index 0000000..b3d746f --- /dev/null +++ b/http-proxy-lb.service @@ -0,0 +1,34 @@ +[Unit] +Description=HTTP Proxy Load Balancer +Documentation=https://github.com/Team-Haruki/http-proxy-lb +After=network-online.target +Wants=network-online.target + +[Service] +Type=simple +User=proxy +Group=proxy +ExecStart=/usr/local/bin/http-proxy-lb --config /etc/http-proxy-lb/config.yaml +Restart=on-failure +RestartSec=5 + +# Logging +StandardOutput=journal +StandardError=journal +SyslogIdentifier=http-proxy-lb + +# Security hardening +NoNewPrivileges=true +ProtectSystem=strict +ProtectHome=true +PrivateTmp=true +ReadOnlyPaths=/etc/http-proxy-lb + +# Resource limits (adjust as needed) +LimitNOFILE=65535 + +# Environment +Environment=RUST_LOG=info + +[Install] +WantedBy=multi-user.target diff --git a/src/admin.rs b/src/admin.rs new file mode 100644 index 0000000..dce8bfa --- /dev/null +++ b/src/admin.rs @@ -0,0 +1,438 @@ +//! Admin server providing metrics and status endpoints. +//! +//! Endpoints: +//! * `GET /metrics` — Prometheus-compatible metrics +//! 
* `GET /status` — JSON status of all upstreams +//! * `GET /health` — Simple health check (returns 200 OK) + +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::Instant; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tracing::{debug, error, info}; + +use crate::upstream::UpstreamPool; + +// --------------------------------------------------------------------------- +// Global metrics +// --------------------------------------------------------------------------- + +/// Global metrics counters shared across the application. +pub struct Metrics { + /// Total number of requests received + pub requests_total: AtomicU64, + /// Total number of successful requests + pub requests_success: AtomicU64, + /// Total number of failed requests + pub requests_failed: AtomicU64, + /// Total number of CONNECT requests + pub connect_requests: AtomicU64, + /// Total number of HTTP (non-CONNECT) requests + pub http_requests: AtomicU64, + /// Total number of direct connections (bypassing upstream) + pub direct_connections: AtomicU64, + /// Total bytes received from clients + pub bytes_received: AtomicU64, + /// Total bytes sent to clients + pub bytes_sent: AtomicU64, + /// Current active connections + pub active_connections: AtomicU64, + /// Total health check probes + pub health_checks_total: AtomicU64, + /// Successful health check probes + pub health_checks_success: AtomicU64, + /// Hot reload count + pub hot_reloads: AtomicU64, + /// Start time for uptime calculation + start_time: Instant, +} + +impl Metrics { + pub fn new() -> Arc { + Arc::new(Self { + requests_total: AtomicU64::new(0), + requests_success: AtomicU64::new(0), + requests_failed: AtomicU64::new(0), + connect_requests: AtomicU64::new(0), + http_requests: AtomicU64::new(0), + direct_connections: AtomicU64::new(0), + bytes_received: AtomicU64::new(0), + bytes_sent: AtomicU64::new(0), + active_connections: AtomicU64::new(0), + 
health_checks_total: AtomicU64::new(0), + health_checks_success: AtomicU64::new(0), + hot_reloads: AtomicU64::new(0), + start_time: Instant::now(), + }) + } + + pub fn uptime_secs(&self) -> u64 { + self.start_time.elapsed().as_secs() + } + + pub fn inc_requests_total(&self) { + self.requests_total.fetch_add(1, Ordering::Relaxed); + } + + pub fn inc_requests_success(&self) { + self.requests_success.fetch_add(1, Ordering::Relaxed); + } + + pub fn inc_requests_failed(&self) { + self.requests_failed.fetch_add(1, Ordering::Relaxed); + } + + pub fn inc_connect(&self) { + self.connect_requests.fetch_add(1, Ordering::Relaxed); + } + + pub fn inc_http(&self) { + self.http_requests.fetch_add(1, Ordering::Relaxed); + } + + pub fn inc_direct(&self) { + self.direct_connections.fetch_add(1, Ordering::Relaxed); + } + + #[allow(dead_code)] + pub fn add_bytes_received(&self, n: u64) { + self.bytes_received.fetch_add(n, Ordering::Relaxed); + } + + #[allow(dead_code)] + pub fn add_bytes_sent(&self, n: u64) { + self.bytes_sent.fetch_add(n, Ordering::Relaxed); + } + + pub fn inc_active(&self) { + self.active_connections.fetch_add(1, Ordering::Relaxed); + } + + pub fn dec_active(&self) { + self.active_connections.fetch_sub(1, Ordering::Relaxed); + } + + pub fn inc_health_check(&self) { + self.health_checks_total.fetch_add(1, Ordering::Relaxed); + } + + pub fn inc_health_check_success(&self) { + self.health_checks_success.fetch_add(1, Ordering::Relaxed); + } + + pub fn inc_hot_reload(&self) { + self.hot_reloads.fetch_add(1, Ordering::Relaxed); + } +} + +impl Default for Metrics { + fn default() -> Self { + Self { + requests_total: AtomicU64::new(0), + requests_success: AtomicU64::new(0), + requests_failed: AtomicU64::new(0), + connect_requests: AtomicU64::new(0), + http_requests: AtomicU64::new(0), + direct_connections: AtomicU64::new(0), + bytes_received: AtomicU64::new(0), + bytes_sent: AtomicU64::new(0), + active_connections: AtomicU64::new(0), + health_checks_total: 
AtomicU64::new(0), + health_checks_success: AtomicU64::new(0), + hot_reloads: AtomicU64::new(0), + start_time: Instant::now(), + } + } +} + +// --------------------------------------------------------------------------- +// Admin server +// --------------------------------------------------------------------------- + +/// Run the admin HTTP server on the given address. +pub async fn run_admin_server(addr: String, pool: Arc, metrics: Arc) { + let listener = match TcpListener::bind(&addr).await { + Ok(l) => l, + Err(e) => { + error!(addr = %addr, error = %e, "failed to bind admin server"); + return; + } + }; + + info!(addr = %addr, "admin server listening"); + + loop { + match listener.accept().await { + Ok((stream, _)) => { + let pool = Arc::clone(&pool); + let metrics = Arc::clone(&metrics); + tokio::spawn(async move { + if let Err(e) = handle_admin_request(stream, &pool, &metrics).await { + debug!(error = %e, "admin request error"); + } + }); + } + Err(e) => { + error!(error = %e, "admin accept error"); + } + } + } +} + +async fn handle_admin_request( + mut stream: TcpStream, + pool: &Arc, + metrics: &Arc, +) -> std::io::Result<()> { + let mut buf = [0u8; 1024]; + let n = stream.read(&mut buf).await?; + if n == 0 { + return Ok(()); + } + + let request = String::from_utf8_lossy(&buf[..n]); + let first_line = request.lines().next().unwrap_or(""); + + if first_line.starts_with("GET /metrics") { + let body = build_prometheus_metrics(pool, metrics); + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).await?; + } else if first_line.starts_with("GET /status") { + let body = build_status_json(pool, metrics); + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).await?; + } 
else if first_line.starts_with("GET /health") { + let body = r#"{"status":"ok"}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).await?; + } else { + let body = "Not Found"; + let response = format!( + "HTTP/1.1 404 Not Found\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).await?; + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Prometheus metrics format +// --------------------------------------------------------------------------- + +fn build_prometheus_metrics(pool: &Arc, metrics: &Arc) -> String { + let mut out = String::with_capacity(4096); + + // Basic metrics + out.push_str("# HELP http_proxy_lb_uptime_seconds Time since proxy started\n"); + out.push_str("# TYPE http_proxy_lb_uptime_seconds gauge\n"); + out.push_str(&format!( + "http_proxy_lb_uptime_seconds {}\n", + metrics.uptime_secs() + )); + + out.push_str("# HELP http_proxy_lb_requests_total Total number of requests received\n"); + out.push_str("# TYPE http_proxy_lb_requests_total counter\n"); + out.push_str(&format!( + "http_proxy_lb_requests_total {}\n", + metrics.requests_total.load(Ordering::Relaxed) + )); + + out.push_str("# HELP http_proxy_lb_requests_success Total number of successful requests\n"); + out.push_str("# TYPE http_proxy_lb_requests_success counter\n"); + out.push_str(&format!( + "http_proxy_lb_requests_success {}\n", + metrics.requests_success.load(Ordering::Relaxed) + )); + + out.push_str("# HELP http_proxy_lb_requests_failed Total number of failed requests\n"); + out.push_str("# TYPE http_proxy_lb_requests_failed counter\n"); + out.push_str(&format!( + "http_proxy_lb_requests_failed {}\n", + metrics.requests_failed.load(Ordering::Relaxed) + )); + + out.push_str("# HELP 
http_proxy_lb_connect_requests Total number of CONNECT requests\n"); + out.push_str("# TYPE http_proxy_lb_connect_requests counter\n"); + out.push_str(&format!( + "http_proxy_lb_connect_requests {}\n", + metrics.connect_requests.load(Ordering::Relaxed) + )); + + out.push_str("# HELP http_proxy_lb_http_requests Total number of HTTP requests\n"); + out.push_str("# TYPE http_proxy_lb_http_requests counter\n"); + out.push_str(&format!( + "http_proxy_lb_http_requests {}\n", + metrics.http_requests.load(Ordering::Relaxed) + )); + + out.push_str( + "# HELP http_proxy_lb_direct_connections Total number of direct (non-proxy) connections\n", + ); + out.push_str("# TYPE http_proxy_lb_direct_connections counter\n"); + out.push_str(&format!( + "http_proxy_lb_direct_connections {}\n", + metrics.direct_connections.load(Ordering::Relaxed) + )); + + out.push_str("# HELP http_proxy_lb_bytes_received Total bytes received from clients\n"); + out.push_str("# TYPE http_proxy_lb_bytes_received counter\n"); + out.push_str(&format!( + "http_proxy_lb_bytes_received {}\n", + metrics.bytes_received.load(Ordering::Relaxed) + )); + + out.push_str("# HELP http_proxy_lb_bytes_sent Total bytes sent to clients\n"); + out.push_str("# TYPE http_proxy_lb_bytes_sent counter\n"); + out.push_str(&format!( + "http_proxy_lb_bytes_sent {}\n", + metrics.bytes_sent.load(Ordering::Relaxed) + )); + + out.push_str("# HELP http_proxy_lb_active_connections Current number of active connections\n"); + out.push_str("# TYPE http_proxy_lb_active_connections gauge\n"); + out.push_str(&format!( + "http_proxy_lb_active_connections {}\n", + metrics.active_connections.load(Ordering::Relaxed) + )); + + out.push_str("# HELP http_proxy_lb_health_checks_total Total health check probes\n"); + out.push_str("# TYPE http_proxy_lb_health_checks_total counter\n"); + out.push_str(&format!( + "http_proxy_lb_health_checks_total {}\n", + metrics.health_checks_total.load(Ordering::Relaxed) + )); + + out.push_str( + "# HELP 
http_proxy_lb_health_checks_success Successful health check probes\n", + ); + out.push_str("# TYPE http_proxy_lb_health_checks_success counter\n"); + out.push_str(&format!( + "http_proxy_lb_health_checks_success {}\n", + metrics.health_checks_success.load(Ordering::Relaxed) + )); + + out.push_str("# HELP http_proxy_lb_hot_reloads Total config hot reloads\n"); + out.push_str("# TYPE http_proxy_lb_hot_reloads counter\n"); + out.push_str(&format!( + "http_proxy_lb_hot_reloads {}\n", + metrics.hot_reloads.load(Ordering::Relaxed) + )); + + // Per-upstream metrics + out.push_str("# HELP http_proxy_lb_upstream_online Whether upstream is online (1) or offline (0)\n"); + out.push_str("# TYPE http_proxy_lb_upstream_online gauge\n"); + + out.push_str("# HELP http_proxy_lb_upstream_active_connections Active connections per upstream\n"); + out.push_str("# TYPE http_proxy_lb_upstream_active_connections gauge\n"); + + out.push_str("# HELP http_proxy_lb_upstream_latency_ms Latency EMA in milliseconds per upstream\n"); + out.push_str("# TYPE http_proxy_lb_upstream_latency_ms gauge\n"); + + out.push_str("# HELP http_proxy_lb_upstream_consecutive_failures Consecutive failures per upstream\n"); + out.push_str("# TYPE http_proxy_lb_upstream_consecutive_failures gauge\n"); + + for status in pool.all_status() { + let url_escaped = status.url.replace('"', "\\\""); + out.push_str(&format!( + "http_proxy_lb_upstream_online{{url=\"{}\"}} {}\n", + url_escaped, + if status.online { 1 } else { 0 } + )); + out.push_str(&format!( + "http_proxy_lb_upstream_active_connections{{url=\"{}\"}} {}\n", + url_escaped, status.active_conns + )); + out.push_str(&format!( + "http_proxy_lb_upstream_latency_ms{{url=\"{}\"}} {}\n", + url_escaped, status.latency_ema_ms + )); + out.push_str(&format!( + "http_proxy_lb_upstream_consecutive_failures{{url=\"{}\"}} {}\n", + url_escaped, status.consec_failures + )); + } + + out +} + +// --------------------------------------------------------------------------- +// JSON 
status +// --------------------------------------------------------------------------- + +fn build_status_json(pool: &Arc, metrics: &Arc) -> String { + let upstreams: Vec = pool + .all_status() + .into_iter() + .map(|s| { + format!( + r#"{{"url":"{}","online":{},"active_connections":{},"latency_ms":{},"consecutive_failures":{},"weight":{},"priority":{}}}"#, + s.url.replace('"', "\\\""), + s.online, + s.active_conns, + s.latency_ema_ms, + s.consec_failures, + s.weight, + s.priority + ) + }) + .collect(); + + format!( + r#"{{"uptime_seconds":{},"requests_total":{},"requests_success":{},"requests_failed":{},"connect_requests":{},"http_requests":{},"direct_connections":{},"bytes_received":{},"bytes_sent":{},"active_connections":{},"health_checks_total":{},"health_checks_success":{},"hot_reloads":{},"upstreams":[{}]}}"#, + metrics.uptime_secs(), + metrics.requests_total.load(Ordering::Relaxed), + metrics.requests_success.load(Ordering::Relaxed), + metrics.requests_failed.load(Ordering::Relaxed), + metrics.connect_requests.load(Ordering::Relaxed), + metrics.http_requests.load(Ordering::Relaxed), + metrics.direct_connections.load(Ordering::Relaxed), + metrics.bytes_received.load(Ordering::Relaxed), + metrics.bytes_sent.load(Ordering::Relaxed), + metrics.active_connections.load(Ordering::Relaxed), + metrics.health_checks_total.load(Ordering::Relaxed), + metrics.health_checks_success.load(Ordering::Relaxed), + metrics.hot_reloads.load(Ordering::Relaxed), + upstreams.join(",") + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_metrics_increment() { + let m = Metrics::new(); + assert_eq!(m.requests_total.load(Ordering::Relaxed), 0); + m.inc_requests_total(); + m.inc_requests_total(); + assert_eq!(m.requests_total.load(Ordering::Relaxed), 2); + } + + #[test] + fn test_metrics_active_connections() { + let m = Metrics::new(); + m.inc_active(); + m.inc_active(); + assert_eq!(m.active_connections.load(Ordering::Relaxed), 2); + m.dec_active(); + 
assert_eq!(m.active_connections.load(Ordering::Relaxed), 1); + } +} diff --git a/src/config.rs b/src/config.rs index 42edd38..f867ea1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,6 +13,11 @@ pub struct Config { /// Local listen address, e.g. "127.0.0.1:8080" pub listen: String, + /// Optional admin/metrics listen address, e.g. "127.0.0.1:9090" + /// Exposes /metrics (Prometheus) and /status (JSON) endpoints. + #[serde(default)] + pub admin_listen: Option, + /// Load-balancing mode #[serde(default)] pub mode: BalanceMode, @@ -29,6 +34,14 @@ pub struct Config { #[serde(default)] pub domain_policy: DomainPolicyConfig, + /// Resource limits and timeouts + #[serde(default)] + pub limits: LimitsConfig, + + /// Enable access logging (default: false) + #[serde(default)] + pub access_log: bool, + /// Upstream proxy list #[serde(default)] pub upstream: Vec, @@ -117,6 +130,39 @@ fn default_hc_timeout() -> u64 { 5 } +// --------------------------------------------------------------------------- +// Limits config +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LimitsConfig { + /// Maximum concurrent client connections (0 = unlimited) + #[serde(default)] + pub max_connections: usize, + + /// Request timeout in seconds (0 = unlimited) + #[serde(default)] + pub request_timeout_secs: u64, + + /// Graceful shutdown timeout in seconds + #[serde(default = "default_shutdown_timeout")] + pub shutdown_timeout_secs: u64, +} + +impl Default for LimitsConfig { + fn default() -> Self { + Self { + max_connections: 0, + request_timeout_secs: 0, + shutdown_timeout_secs: default_shutdown_timeout(), + } + } +} + +fn default_shutdown_timeout() -> u64 { + 30 +} + // --------------------------------------------------------------------------- // Upstream config // --------------------------------------------------------------------------- @@ -310,4 +356,38 @@ upstream: [] 
assert_eq!(cfg.domain_policy.domains.len(), 2); assert_eq!(cfg.domain_policy.domains[0], "example.com"); } + + #[test] + fn test_parse_yaml_limits() { + let yaml = r#" +listen: "0.0.0.0:8080" +admin_listen: "127.0.0.1:9090" +access_log: true +limits: + max_connections: 1000 + request_timeout_secs: 60 + shutdown_timeout_secs: 15 +upstream: [] +"#; + let cfg: Config = yaml_serde::from_str(yaml).unwrap(); + assert_eq!(cfg.admin_listen, Some("127.0.0.1:9090".to_string())); + assert!(cfg.access_log); + assert_eq!(cfg.limits.max_connections, 1000); + assert_eq!(cfg.limits.request_timeout_secs, 60); + assert_eq!(cfg.limits.shutdown_timeout_secs, 15); + } + + #[test] + fn test_parse_yaml_limits_defaults() { + let yaml = r#" +listen: "0.0.0.0:8080" +upstream: [] +"#; + let cfg: Config = yaml_serde::from_str(yaml).unwrap(); + assert!(cfg.admin_listen.is_none()); + assert!(!cfg.access_log); + assert_eq!(cfg.limits.max_connections, 0); + assert_eq!(cfg.limits.request_timeout_secs, 0); + assert_eq!(cfg.limits.shutdown_timeout_secs, 30); + } } diff --git a/src/health.rs b/src/health.rs index 1cda8e7..a62b112 100644 --- a/src/health.rs +++ b/src/health.rs @@ -6,6 +6,7 @@ use tokio::net::TcpStream; use tokio::time::timeout; use tracing::{debug, info}; +use crate::admin::Metrics; use crate::upstream::UpstreamPool; const CAPTIVE_PROBE_URL: &str = "http://connectivitycheck.gstatic.com/generate_204"; @@ -19,17 +20,22 @@ const PROBE_RESPONSE_BUFFER_SIZE: usize = 256; /// sends a captive-style HTTP probe request through the upstream proxy with /// `timeout_secs` deadline, and marks the upstream **online** if it returns /// HTTP 204. 
-pub async fn run_health_checker(pool: Arc, interval_secs: u64, timeout_secs: u64) { +pub async fn run_health_checker( + pool: Arc, + metrics: Arc, + interval_secs: u64, + timeout_secs: u64, +) { let interval = Duration::from_secs(interval_secs); let probe_timeout = Duration::from_secs(timeout_secs); loop { tokio::time::sleep(interval).await; - probe_offline(&pool, probe_timeout).await; + probe_offline(&pool, &metrics, probe_timeout).await; } } -async fn probe_offline(pool: &Arc, probe_timeout: Duration) { +async fn probe_offline(pool: &Arc, metrics: &Arc, probe_timeout: Duration) { let offline = pool.offline_entries(); if offline.is_empty() { return; @@ -44,7 +50,9 @@ async fn probe_offline(pool: &Arc, probe_timeout: Duration) { .into_iter() .map(|entry| { let t = probe_timeout; + let metrics = Arc::clone(metrics); tokio::spawn(async move { + metrics.inc_health_check(); let addr = match entry.host_port() { Ok((h, p)) => format!("{h}:{p}"), Err(e) => { @@ -56,6 +64,7 @@ async fn probe_offline(pool: &Arc, probe_timeout: Duration) { Ok(true) => { info!(upstream = %entry.config.url, "health check OK (HTTP 204) — marking online"); entry.mark_online(); + metrics.inc_health_check_success(); } Ok(false) => { debug!(upstream = %entry.config.url, "health check got non-204 response"); @@ -120,7 +129,7 @@ async fn probe_proxy_by_captive_http( mod tests { use super::*; use crate::config::{ - BalanceMode, Config, DomainPolicyConfig, HealthCheckConfig, UpstreamConfig, + BalanceMode, Config, DomainPolicyConfig, HealthCheckConfig, LimitsConfig, UpstreamConfig, }; use crate::upstream::UpstreamPool; use tokio::net::TcpListener; @@ -142,10 +151,13 @@ mod tests { let cfg = Config { listen: "127.0.0.1:8080".to_string(), + admin_listen: None, mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, upstream: vec![UpstreamConfig { url: 
format!("http://127.0.0.1:{port}"), weight: 1, @@ -155,11 +167,12 @@ mod tests { }], }; let pool = UpstreamPool::from_config(&cfg); + let metrics = Metrics::new(); let (_, entry) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); entry.mark_offline(); assert!(!entry.is_online()); - probe_offline(&pool, Duration::from_secs(2)).await; + probe_offline(&pool, &metrics, Duration::from_secs(2)).await; assert!(entry.is_online()); } @@ -179,10 +192,13 @@ mod tests { let cfg = Config { listen: "127.0.0.1:8080".to_string(), + admin_listen: None, mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, upstream: vec![UpstreamConfig { url: format!("http://127.0.0.1:{port}"), weight: 1, @@ -192,11 +208,12 @@ mod tests { }], }; let pool = UpstreamPool::from_config(&cfg); + let metrics = Metrics::new(); let (_, entry) = pool.select(BalanceMode::RoundRobin, &[]).unwrap(); entry.mark_offline(); assert!(!entry.is_online()); - probe_offline(&pool, Duration::from_secs(2)).await; + probe_offline(&pool, &metrics, Duration::from_secs(2)).await; assert!(!entry.is_online()); } } diff --git a/src/main.rs b/src/main.rs index 2621fab..b219368 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,20 @@ +mod admin; mod config; mod health; mod proxy; mod upstream; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; use anyhow::{Context, Result}; use clap::Parser; -use tokio::net::TcpListener; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::Semaphore; use tracing::{error, info, warn}; +use admin::Metrics; use config::{file_mtime, load_config}; use upstream::UpstreamPool; @@ -26,6 +30,10 @@ struct Cli { /// Path to the YAML configuration file #[arg(short, long, default_value = "config.yaml")] config: String, + + /// Validate configuration file and exit + #[arg(long)] + check: bool, } // 
--------------------------------------------------------------------------- @@ -49,64 +57,273 @@ async fn main() -> Result<()> { let cfg = load_config(&cfg_path).with_context(|| format!("failed to load config from {cfg_path}"))?; + // Config check mode + if cli.check { + println!("Configuration file '{}' is valid.", cfg_path); + println!(" Listen: {}", cfg.listen); + if let Some(ref admin) = cfg.admin_listen { + println!(" Admin: {}", admin); + } + println!(" Mode: {:?}", cfg.mode); + println!(" Upstreams: {}", cfg.upstream.len()); + for u in &cfg.upstream { + println!(" - {} (weight={}, priority={})", u.url, u.weight, u.priority); + } + println!(" Domain policy: {:?}", cfg.domain_policy.mode); + println!(" Access log: {}", cfg.access_log); + println!(" Limits:"); + println!(" max_connections: {}", if cfg.limits.max_connections == 0 { "unlimited".to_string() } else { cfg.limits.max_connections.to_string() }); + println!(" request_timeout: {}s", if cfg.limits.request_timeout_secs == 0 { "unlimited".to_string() } else { cfg.limits.request_timeout_secs.to_string() }); + println!(" shutdown_timeout: {}s", cfg.limits.shutdown_timeout_secs); + return Ok(()); + } + info!(listen = %cfg.listen, mode = ?cfg.mode, upstreams = cfg.upstream.len(), "starting"); let listen_addr = cfg.listen.clone(); + let admin_listen = cfg.admin_listen.clone(); let mode = cfg.mode; let reload_interval = cfg.reload_interval_secs; let hc_interval = cfg.health_check.interval_secs; let hc_timeout = cfg.health_check.timeout_secs; let domain_policy = Arc::new(cfg.domain_policy.clone()); + let access_log = cfg.access_log; + let max_connections = cfg.limits.max_connections; + let request_timeout = if cfg.limits.request_timeout_secs > 0 { + Some(Duration::from_secs(cfg.limits.request_timeout_secs)) + } else { + None + }; + let shutdown_timeout = Duration::from_secs(cfg.limits.shutdown_timeout_secs); // Build upstream pool let pool = UpstreamPool::from_config(&cfg); + // Create metrics + let metrics = 
Metrics::new(); + + // Connection limiter (None if unlimited) + let conn_semaphore = if max_connections > 0 { + Some(Arc::new(Semaphore::new(max_connections))) + } else { + None + }; + + // Shutdown flag + let shutdown = Arc::new(AtomicBool::new(false)); + // Bind listener let listener = TcpListener::bind(&listen_addr) .await .with_context(|| format!("failed to bind to {listen_addr}"))?; info!(addr = %listen_addr, "listening"); + // --- Spawn admin server --- + if let Some(admin_addr) = admin_listen { + let pool = Arc::clone(&pool); + let metrics = Arc::clone(&metrics); + tokio::spawn(async move { + admin::run_admin_server(admin_addr, pool, metrics).await; + }); + } + // --- Spawn active health checker --- { let pool = Arc::clone(&pool); + let metrics = Arc::clone(&metrics); tokio::spawn(async move { - health::run_health_checker(pool, hc_interval, hc_timeout).await; + health::run_health_checker(pool, metrics, hc_interval, hc_timeout).await; }); } // --- Spawn hot-reload watcher --- if reload_interval > 0 { let pool = Arc::clone(&pool); + let metrics = Arc::clone(&metrics); let path = cfg_path.clone(); tokio::spawn(async move { - run_hot_reload(pool, path, reload_interval).await; + run_hot_reload(pool, metrics, path, reload_interval).await; }); } - // --- Accept loop --- + // --- Run accept loop with graceful shutdown --- + run_server( + listener, + pool, + mode, + domain_policy, + metrics, + conn_semaphore, + request_timeout, + access_log, + shutdown, + shutdown_timeout, + ) + .await +} + +// --------------------------------------------------------------------------- +// Server with graceful shutdown +// --------------------------------------------------------------------------- + +#[cfg(unix)] +#[allow(clippy::too_many_arguments)] +async fn run_server( + listener: TcpListener, + pool: Arc, + mode: config::BalanceMode, + domain_policy: Arc, + metrics: Arc, + conn_semaphore: Option>, + request_timeout: Option, + access_log: bool, + shutdown: Arc, + shutdown_timeout: 
Duration, +) -> Result<()> { + use tokio::signal::unix::{signal, SignalKind}; + + let mut sigterm = signal(SignalKind::terminate()).expect("failed to register SIGTERM handler"); + loop { - match listener.accept().await { - Ok((stream, peer)) => { - debug_assert_ne!(peer.port(), 0); - let pool = Arc::clone(&pool); - let domain_policy = Arc::clone(&domain_policy); - tokio::spawn(async move { - proxy::handle_client(stream, pool, mode, domain_policy).await; - }); + tokio::select! { + result = listener.accept() => { + handle_accept( + result, &pool, mode, &domain_policy, &metrics, + &conn_semaphore, request_timeout, access_log + ); } - Err(e) => { - error!(error = %e, "accept error"); + _ = tokio::signal::ctrl_c() => { + info!("received SIGINT, initiating graceful shutdown"); + shutdown.store(true, Ordering::SeqCst); + break; + } + _ = sigterm.recv() => { + info!("received SIGTERM, initiating graceful shutdown"); + shutdown.store(true, Ordering::SeqCst); + break; } } } + + wait_for_shutdown(&metrics, shutdown_timeout).await; + Ok(()) +} + +#[cfg(not(unix))] +#[allow(clippy::too_many_arguments)] +async fn run_server( + listener: TcpListener, + pool: Arc, + mode: config::BalanceMode, + domain_policy: Arc, + metrics: Arc, + conn_semaphore: Option>, + request_timeout: Option, + access_log: bool, + shutdown: Arc, + shutdown_timeout: Duration, +) -> Result<()> { + loop { + tokio::select! 
{ + result = listener.accept() => { + handle_accept( + result, &pool, mode, &domain_policy, &metrics, + &conn_semaphore, request_timeout, access_log + ); + } + _ = tokio::signal::ctrl_c() => { + info!("received SIGINT, initiating graceful shutdown"); + shutdown.store(true, Ordering::SeqCst); + break; + } + } + } + + wait_for_shutdown(&metrics, shutdown_timeout).await; + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +fn handle_accept( + result: std::io::Result<(TcpStream, std::net::SocketAddr)>, + pool: &Arc, + mode: config::BalanceMode, + domain_policy: &Arc, + metrics: &Arc, + conn_semaphore: &Option>, + request_timeout: Option, + access_log: bool, +) { + match result { + Ok((stream, peer)) => { + // Check connection limit + let permit = if let Some(ref sem) = conn_semaphore { + match sem.clone().try_acquire_owned() { + Ok(p) => Some(p), + Err(_) => { + warn!(peer = %peer, "connection limit reached, rejecting"); + drop(stream); + return; + } + } + } else { + None + }; + + let pool = Arc::clone(pool); + let domain_policy = Arc::clone(domain_policy); + let metrics = Arc::clone(metrics); + + metrics.inc_active(); + + tokio::spawn(async move { + if let Some(timeout) = request_timeout { + let _ = tokio::time::timeout( + timeout, + proxy::handle_client(stream, pool, mode, domain_policy, &metrics, access_log), + ) + .await; + } else { + proxy::handle_client(stream, pool, mode, domain_policy, &metrics, access_log).await; + } + metrics.dec_active(); + drop(permit); + }); + } + Err(e) => { + error!(error = %e, "accept error"); + } + } +} + +async fn wait_for_shutdown(metrics: &Arc, shutdown_timeout: Duration) { + info!(timeout_secs = shutdown_timeout.as_secs(), "waiting for active connections to complete"); + + let wait_start = std::time::Instant::now(); + loop { + let active = metrics.active_connections.load(Ordering::Relaxed); + if active == 0 { + info!("all connections closed, shutting down"); + break; + } + if wait_start.elapsed() >= shutdown_timeout { + 
warn!(active = active, "shutdown timeout reached, forcing shutdown"); + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } } // --------------------------------------------------------------------------- // Hot-reload loop // --------------------------------------------------------------------------- -async fn run_hot_reload(pool: Arc, cfg_path: String, interval_secs: u64) { +async fn run_hot_reload( + pool: Arc, + metrics: Arc, + cfg_path: String, + interval_secs: u64, +) { let mut last_mtime = file_mtime(&cfg_path); let interval = Duration::from_secs(interval_secs); @@ -122,6 +339,7 @@ async fn run_hot_reload(pool: Arc, cfg_path: String, interval_secs match load_config(&cfg_path) { Ok(new_cfg) => { pool.reload(&new_cfg); + metrics.inc_hot_reload(); last_mtime = current_mtime; } Err(e) => { @@ -130,3 +348,4 @@ async fn run_hot_reload(pool: Arc, cfg_path: String, interval_secs } } } + diff --git a/src/proxy.rs b/src/proxy.rs index cc76f3b..1f47d07 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -13,8 +13,9 @@ use anyhow::{anyhow, bail, Result}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::time::{timeout, Duration}; -use tracing::{debug, warn}; +use tracing::{debug, info, warn}; +use crate::admin::Metrics; use crate::config::{BalanceMode, DomainPolicyConfig, DomainPolicyMode}; use crate::upstream::UpstreamPool; @@ -38,6 +39,8 @@ pub async fn handle_client( pool: Arc, mode: BalanceMode, domain_policy: Arc, + metrics: &Metrics, + access_log: bool, ) { let peer = client .peer_addr() @@ -48,10 +51,13 @@ pub async fn handle_client( loop { match read_headers(&mut client).await { Ok(Some(buf)) => { - match dispatch(&mut client, &buf, &pool, mode, &domain_policy).await { + metrics.inc_requests_total(); + let start = Instant::now(); + match dispatch(&mut client, &buf, &pool, mode, &domain_policy, metrics, access_log, &peer, start).await { Ok(true) => continue, // keep-alive: read next request Ok(false) => break, 
Err(e) => { + metrics.inc_requests_failed(); debug!(peer = %peer, error = %e, "request dispatch error"); break; } @@ -73,12 +79,17 @@ pub async fn handle_client( // --------------------------------------------------------------------------- /// Returns `Ok(true)` if the client connection should be kept alive. +#[allow(clippy::too_many_arguments)] async fn dispatch( client: &mut TcpStream, buf: &[u8], pool: &Arc, mode: BalanceMode, domain_policy: &Arc, + metrics: &Metrics, + access_log: bool, + peer: &str, + start: Instant, ) -> Result { // --- parse request line + headers --- let mut raw_headers = [httparse::EMPTY_HEADER; 96]; @@ -108,17 +119,42 @@ async fn dispatch( // --- CONNECT (HTTPS tunnel) --- if method.eq_ignore_ascii_case("CONNECT") { + metrics.inc_connect(); let (host, port) = parse_connect_target(&path)?; let use_proxy = should_use_proxy(&host, domain_policy); - return handle_connect(client, &host, port, pool, mode, use_proxy) - .await - .map(|_| false); + if !use_proxy { + metrics.inc_direct(); + } + let result = handle_connect(client, &host, port, pool, mode, use_proxy).await; + let elapsed_ms = start.elapsed().as_millis(); + let status = if result.is_ok() { 200 } else { 502 }; + if result.is_ok() { + metrics.inc_requests_success(); + } else { + metrics.inc_requests_failed(); + } + if access_log { + info!( + peer = %peer, + method = "CONNECT", + target = %path, + status = status, + elapsed_ms = elapsed_ms, + direct = !use_proxy, + "access" + ); + } + return result.map(|_| false); } // --- Plain HTTP --- + metrics.inc_http(); let (target_host, target_port) = parse_http_target(&path, headers)?; let use_proxy = should_use_proxy(&target_host, domain_policy); - handle_http( + if !use_proxy { + metrics.inc_direct(); + } + let result = handle_http( client, &method, &path, @@ -132,7 +168,26 @@ async fn dispatch( use_proxy, (&target_host, target_port), ) - .await + .await; + let elapsed_ms = start.elapsed().as_millis(); + let status = if result.is_ok() { 200 } 
else { 502 }; + if result.is_ok() { + metrics.inc_requests_success(); + } else { + metrics.inc_requests_failed(); + } + if access_log { + info!( + peer = %peer, + method = %method, + target = %path, + status = status, + elapsed_ms = elapsed_ms, + direct = !use_proxy, + "access" + ); + } + result } // --------------------------------------------------------------------------- diff --git a/src/upstream.rs b/src/upstream.rs index 77cc628..c9efb0c 100644 --- a/src/upstream.rs +++ b/src/upstream.rs @@ -131,6 +131,21 @@ impl UpstreamEntry { } } +// --------------------------------------------------------------------------- +// UpstreamStatus — snapshot for admin API +// --------------------------------------------------------------------------- + +/// Status snapshot of an upstream for metrics/admin API. +pub struct UpstreamStatus { + pub url: String, + pub online: bool, + pub active_conns: i64, + pub latency_ema_ms: u64, + pub consec_failures: u32, + pub weight: u32, + pub priority: u32, +} + // --------------------------------------------------------------------------- // UpstreamPool — thread-safe collection of upstream entries // --------------------------------------------------------------------------- @@ -225,6 +240,23 @@ impl UpstreamPool { .collect() } + /// Get status of all upstreams for metrics/admin API. 
+ pub fn all_status(&self) -> Vec { + self.entries + .lock() + .iter() + .map(|e| UpstreamStatus { + url: e.config.url.clone(), + online: e.is_online(), + active_conns: e.active_conns.load(Ordering::Relaxed), + latency_ema_ms: e.latency_ema_ms.load(Ordering::Relaxed), + consec_failures: e.consec_failures.load(Ordering::Relaxed), + weight: e.config.weight, + priority: e.config.priority, + }) + .collect() + } + // ---- hot reload -------------------------------------------------------- /// Replace the upstream list from a new config while preserving state/stats @@ -268,16 +300,19 @@ impl UpstreamPool { mod tests { use super::*; use crate::config::{ - BalanceMode, Config, DomainPolicyConfig, HealthCheckConfig, UpstreamConfig, + BalanceMode, Config, DomainPolicyConfig, HealthCheckConfig, LimitsConfig, UpstreamConfig, }; fn make_pool(urls: &[&str]) -> Arc { let cfg = Config { listen: "127.0.0.1:8080".to_string(), + admin_listen: None, mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, upstream: urls .iter() .map(|u| UpstreamConfig { @@ -370,10 +405,13 @@ mod tests { // Reload with same upstreams let new_cfg = Config { listen: "127.0.0.1:8080".to_string(), + admin_listen: None, mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), @@ -422,10 +460,13 @@ mod tests { fn weighted_round_robin_respects_weights() { let cfg = Config { listen: "127.0.0.1:8080".to_string(), + admin_listen: None, mode: BalanceMode::RoundRobin, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, upstream: vec![ UpstreamConfig 
{ url: "http://a:1".to_string(), @@ -464,10 +505,13 @@ mod tests { fn priority_mode_prefers_lowest_priority_value() { let cfg = Config { listen: "127.0.0.1:8080".to_string(), + admin_listen: None, mode: BalanceMode::Priority, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), @@ -502,10 +546,13 @@ mod tests { fn priority_mode_falls_back_to_next_online_priority() { let cfg = Config { listen: "127.0.0.1:8080".to_string(), + admin_listen: None, mode: BalanceMode::Priority, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), @@ -536,10 +583,13 @@ mod tests { fn reload_updates_priority_for_existing_upstream() { let cfg = Config { listen: "127.0.0.1:8080".to_string(), + admin_listen: None, mode: BalanceMode::Priority, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), @@ -563,10 +613,13 @@ mod tests { let reloaded = Config { listen: "127.0.0.1:8080".to_string(), + admin_listen: None, mode: BalanceMode::Priority, reload_interval_secs: 0, health_check: HealthCheckConfig::default(), domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, upstream: vec![ UpstreamConfig { url: "http://a:1".to_string(), From 258d0b9183d34e9172ec5d0260fafa578e0d777b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=9F=E9=9B=B2=E5=B8=8C=E5=87=AA?= <70424266+MejiroRina@users.noreply.github.com> Date: Mon, 23 Mar 2026 06:45:56 +0800 Subject: [PATCH 10/13] [Fix] Fix some bugs and add more tests --- Cargo.toml | 20 
+-- Dockerfile | 5 +- README.md | 30 ++++ src/admin.rs | 20 ++- src/config.rs | 76 +++++++++ src/main.rs | 105 ++++++++----- src/proxy.rs | 428 ++++++++++++++++++++++++++++++++++++-------------- 7 files changed, 503 insertions(+), 181 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index be30cf2..42a8a89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,18 +11,18 @@ path = "src/main.rs" [dependencies] # Async runtime -tokio = { version = "1", features = ["full"] } +tokio = { version = "1.50.0", features = ["full"] } # HTTP/1.x request/response parsing -httparse = "1" +httparse = "1.10.1" # Serialization -serde = { version = "1", features = ["derive"] } -yaml_serde = "0.10" +serde = { version = "1.0.228", features = ["derive"] } +yaml_serde = "0.10.4" # CLI -clap = { version = "4", features = ["derive"] } +clap = { version = "4.6.0", features = ["derive"] } # Logging -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } +tracing = "0.1.44" +tracing-subscriber = { version = "0.3.23", features = ["env-filter", "fmt"] } # Utilities -anyhow = "1" -base64 = "0.22" -parking_lot = "0.12" +anyhow = "1.0.102" +base64 = "0.22.1" +parking_lot = "0.12.5" diff --git a/Dockerfile b/Dockerfile index 7daf6a2..18dce1c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Build stage -FROM rust:1.75-slim-bookworm AS builder +FROM rust:slim-bookworm AS builder WORKDIR /app @@ -21,10 +21,11 @@ FROM debian:bookworm-slim # Install runtime dependencies RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates \ + curl \ && rm -rf /var/lib/apt/lists/* # Create non-root user -RUN useradd -r -s /bin/false proxy +RUN getent passwd proxy >/dev/null || useradd -r -s /bin/false proxy # Copy binary from builder COPY --from=builder /app/target/release/http-proxy-lb /usr/local/bin/ diff --git a/README.md b/README.md index b7ae14f..a69e646 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,30 @@ Set `RUST_LOG=debug` for verbose 
logging. ./target/release/http-proxy-lb --config config.yaml --check ``` +## Validation + +```bash +# unit + integration tests +cargo test + +# lint +cargo clippy -- -D warnings + +# formatting +cargo fmt -- --check + +# container smoke test +./scripts/docker-smoke.sh +``` + +The integration test suite exercises: + +* config validation failure paths +* direct HTTP forwarding with real response-status propagation +* request timeout handling (`504 Gateway Timeout`) +* connection limiting (`503 Service Unavailable`) +* admin metrics/status byte counters and request counters + ## Configuration ```yaml @@ -236,6 +260,12 @@ sudo systemctl enable --now http-proxy-lb docker compose --profile monitoring up -d ``` +For a local container smoke test without touching your own config, run: + +```bash +./scripts/docker-smoke.sh +``` + ## License MIT diff --git a/src/admin.rs b/src/admin.rs index dce8bfa..fb0d324 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -319,9 +319,7 @@ fn build_prometheus_metrics(pool: &Arc, metrics: &Arc) -> metrics.health_checks_total.load(Ordering::Relaxed) )); - out.push_str( - "# HELP http_proxy_lb_health_checks_success Successful health check probes\n", - ); + out.push_str("# HELP http_proxy_lb_health_checks_success Successful health check probes\n"); out.push_str("# TYPE http_proxy_lb_health_checks_success counter\n"); out.push_str(&format!( "http_proxy_lb_health_checks_success {}\n", @@ -336,16 +334,24 @@ fn build_prometheus_metrics(pool: &Arc, metrics: &Arc) -> )); // Per-upstream metrics - out.push_str("# HELP http_proxy_lb_upstream_online Whether upstream is online (1) or offline (0)\n"); + out.push_str( + "# HELP http_proxy_lb_upstream_online Whether upstream is online (1) or offline (0)\n", + ); out.push_str("# TYPE http_proxy_lb_upstream_online gauge\n"); - out.push_str("# HELP http_proxy_lb_upstream_active_connections Active connections per upstream\n"); + out.push_str( + "# HELP http_proxy_lb_upstream_active_connections Active connections 
per upstream\n", + ); out.push_str("# TYPE http_proxy_lb_upstream_active_connections gauge\n"); - out.push_str("# HELP http_proxy_lb_upstream_latency_ms Latency EMA in milliseconds per upstream\n"); + out.push_str( + "# HELP http_proxy_lb_upstream_latency_ms Latency EMA in milliseconds per upstream\n", + ); out.push_str("# TYPE http_proxy_lb_upstream_latency_ms gauge\n"); - out.push_str("# HELP http_proxy_lb_upstream_consecutive_failures Consecutive failures per upstream\n"); + out.push_str( + "# HELP http_proxy_lb_upstream_consecutive_failures Consecutive failures per upstream\n", + ); out.push_str("# TYPE http_proxy_lb_upstream_consecutive_failures gauge\n"); for status in pool.all_status() { diff --git a/src/config.rs b/src/config.rs index f867ea1..8d0ebfe 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,7 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; use std::fs; +use std::net::SocketAddr; use std::path::Path; use std::time::SystemTime; @@ -234,6 +235,27 @@ pub fn load_config(path: &str) -> Result { Ok(config) } +pub fn validate_config(config: &Config) -> Result<()> { + config + .listen + .parse::() + .with_context(|| format!("invalid listen address: {}", config.listen))?; + + if let Some(admin_listen) = &config.admin_listen { + admin_listen + .parse::() + .with_context(|| format!("invalid admin_listen address: {admin_listen}"))?; + } + + for upstream in &config.upstream { + upstream + .host_port() + .with_context(|| format!("invalid upstream URL: {}", upstream.url))?; + } + + Ok(()) +} + /// Returns the mtime of `path`, used to detect file changes for hot reload. 
pub fn file_mtime(path: &str) -> Option { Path::new(path).metadata().ok()?.modified().ok() @@ -390,4 +412,58 @@ upstream: [] assert_eq!(cfg.limits.request_timeout_secs, 0); assert_eq!(cfg.limits.shutdown_timeout_secs, 30); } + + #[test] + fn test_validate_config_rejects_invalid_listen() { + let cfg = Config { + listen: "not-an-addr".to_string(), + admin_listen: None, + mode: BalanceMode::RoundRobin, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, + upstream: vec![], + }; + assert!(validate_config(&cfg).is_err()); + } + + #[test] + fn test_validate_config_rejects_invalid_admin_listen() { + let cfg = Config { + listen: "127.0.0.1:8080".to_string(), + admin_listen: Some("invalid-admin".to_string()), + mode: BalanceMode::RoundRobin, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, + upstream: vec![], + }; + assert!(validate_config(&cfg).is_err()); + } + + #[test] + fn test_validate_config_rejects_invalid_upstream() { + let cfg = Config { + listen: "127.0.0.1:8080".to_string(), + admin_listen: None, + mode: BalanceMode::RoundRobin, + reload_interval_secs: 0, + health_check: HealthCheckConfig::default(), + domain_policy: DomainPolicyConfig::default(), + limits: LimitsConfig::default(), + access_log: false, + upstream: vec![UpstreamConfig { + url: "http://".to_string(), + weight: 1, + priority: 100, + username: None, + password: None, + }], + }; + assert!(validate_config(&cfg).is_err()); + } } diff --git a/src/main.rs b/src/main.rs index b219368..3aebe5c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,11 +11,10 @@ use std::time::Duration; use anyhow::{Context, Result}; use clap::Parser; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::Semaphore; use tracing::{error, info, warn}; use admin::Metrics; -use 
config::{file_mtime, load_config}; +use config::{file_mtime, load_config, validate_config}; use upstream::UpstreamPool; // --------------------------------------------------------------------------- @@ -56,6 +55,7 @@ async fn main() -> Result<()> { // Load initial config let cfg = load_config(&cfg_path).with_context(|| format!("failed to load config from {cfg_path}"))?; + validate_config(&cfg).with_context(|| format!("invalid config in {cfg_path}"))?; // Config check mode if cli.check { @@ -67,14 +67,34 @@ async fn main() -> Result<()> { println!(" Mode: {:?}", cfg.mode); println!(" Upstreams: {}", cfg.upstream.len()); for u in &cfg.upstream { - println!(" - {} (weight={}, priority={})", u.url, u.weight, u.priority); + println!( + " - {} (weight={}, priority={})", + u.url, u.weight, u.priority + ); } println!(" Domain policy: {:?}", cfg.domain_policy.mode); println!(" Access log: {}", cfg.access_log); println!(" Limits:"); - println!(" max_connections: {}", if cfg.limits.max_connections == 0 { "unlimited".to_string() } else { cfg.limits.max_connections.to_string() }); - println!(" request_timeout: {}s", if cfg.limits.request_timeout_secs == 0 { "unlimited".to_string() } else { cfg.limits.request_timeout_secs.to_string() }); - println!(" shutdown_timeout: {}s", cfg.limits.shutdown_timeout_secs); + println!( + " max_connections: {}", + if cfg.limits.max_connections == 0 { + "unlimited".to_string() + } else { + cfg.limits.max_connections.to_string() + } + ); + println!( + " request_timeout: {}", + if cfg.limits.request_timeout_secs == 0 { + "unlimited".to_string() + } else { + format!("{}s", cfg.limits.request_timeout_secs) + } + ); + println!( + " shutdown_timeout: {}s", + cfg.limits.shutdown_timeout_secs + ); return Ok(()); } @@ -103,11 +123,7 @@ async fn main() -> Result<()> { let metrics = Metrics::new(); // Connection limiter (None if unlimited) - let conn_semaphore = if max_connections > 0 { - Some(Arc::new(Semaphore::new(max_connections))) - } else { - None - 
}; + let connection_limit = (max_connections > 0).then_some(max_connections); // Shutdown flag let shutdown = Arc::new(AtomicBool::new(false)); @@ -153,7 +169,7 @@ async fn main() -> Result<()> { mode, domain_policy, metrics, - conn_semaphore, + connection_limit, request_timeout, access_log, shutdown, @@ -174,7 +190,7 @@ async fn run_server( mode: config::BalanceMode, domain_policy: Arc, metrics: Arc, - conn_semaphore: Option>, + connection_limit: Option, request_timeout: Option, access_log: bool, shutdown: Arc, @@ -189,7 +205,7 @@ async fn run_server( result = listener.accept() => { handle_accept( result, &pool, mode, &domain_policy, &metrics, - &conn_semaphore, request_timeout, access_log + connection_limit, request_timeout, access_log ); } _ = tokio::signal::ctrl_c() => { @@ -217,7 +233,7 @@ async fn run_server( mode: config::BalanceMode, domain_policy: Arc, metrics: Arc, - conn_semaphore: Option>, + connection_limit: Option, request_timeout: Option, access_log: bool, shutdown: Arc, @@ -228,7 +244,7 @@ async fn run_server( result = listener.accept() => { handle_accept( result, &pool, mode, &domain_policy, &metrics, - &conn_semaphore, request_timeout, access_log + connection_limit, request_timeout, access_log ); } _ = tokio::signal::ctrl_c() => { @@ -250,25 +266,23 @@ fn handle_accept( mode: config::BalanceMode, domain_policy: &Arc, metrics: &Arc, - conn_semaphore: &Option>, + connection_limit: Option, request_timeout: Option, access_log: bool, ) { match result { Ok((stream, peer)) => { // Check connection limit - let permit = if let Some(ref sem) = conn_semaphore { - match sem.clone().try_acquire_owned() { - Ok(p) => Some(p), - Err(_) => { - warn!(peer = %peer, "connection limit reached, rejecting"); - drop(stream); - return; - } + if let Some(limit) = connection_limit { + let active = metrics.active_connections.load(Ordering::Relaxed) as usize; + if active >= limit { + warn!(peer = %peer, limit, "connection limit reached, rejecting"); + tokio::spawn(async move 
{ + let _ = proxy::reject_connection(stream).await; + }); + return; } - } else { - None - }; + } let pool = Arc::clone(pool); let domain_policy = Arc::clone(domain_policy); @@ -277,17 +291,17 @@ fn handle_accept( metrics.inc_active(); tokio::spawn(async move { - if let Some(timeout) = request_timeout { - let _ = tokio::time::timeout( - timeout, - proxy::handle_client(stream, pool, mode, domain_policy, &metrics, access_log), - ) - .await; - } else { - proxy::handle_client(stream, pool, mode, domain_policy, &metrics, access_log).await; - } + proxy::handle_client( + stream, + pool, + mode, + domain_policy, + &metrics, + access_log, + request_timeout, + ) + .await; metrics.dec_active(); - drop(permit); }); } Err(e) => { @@ -297,7 +311,10 @@ fn handle_accept( } async fn wait_for_shutdown(metrics: &Arc, shutdown_timeout: Duration) { - info!(timeout_secs = shutdown_timeout.as_secs(), "waiting for active connections to complete"); + info!( + timeout_secs = shutdown_timeout.as_secs(), + "waiting for active connections to complete" + ); let wait_start = std::time::Instant::now(); loop { @@ -307,7 +324,10 @@ async fn wait_for_shutdown(metrics: &Arc, shutdown_timeout: Duration) { break; } if wait_start.elapsed() >= shutdown_timeout { - warn!(active = active, "shutdown timeout reached, forcing shutdown"); + warn!( + active = active, + "shutdown timeout reached, forcing shutdown" + ); break; } tokio::time::sleep(Duration::from_millis(100)).await; @@ -338,6 +358,10 @@ async fn run_hot_reload( info!(path = %cfg_path, "config file changed — reloading"); match load_config(&cfg_path) { Ok(new_cfg) => { + if let Err(e) = validate_config(&new_cfg) { + warn!(error = %e, "hot reload validation failed — keeping current config"); + continue; + } pool.reload(&new_cfg); metrics.inc_hot_reload(); last_mtime = current_mtime; @@ -348,4 +372,3 @@ async fn run_hot_reload( } } } - diff --git a/src/proxy.rs b/src/proxy.rs index 1f47d07..667ba53 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -28,6 
+28,42 @@ const CONNECT_TIMEOUT: Duration = Duration::from_secs(15); /// Maximum number of upstream retry attempts per request. const MAX_RETRIES: usize = 3; +#[derive(Clone, Copy)] +struct ProxyResponse { + keep_alive: bool, + response_status: u16, + success: bool, +} + +impl ProxyResponse { + fn success(response_status: u16, keep_alive: bool) -> Self { + Self { + keep_alive, + response_status, + success: true, + } + } + + fn failure(response_status: u16) -> Self { + Self { + keep_alive: false, + response_status, + success: false, + } + } +} + +struct ResponseForwardOutcome { + response_status: u16, + upstream_keep_alive: bool, +} + +struct RequestLogContext { + method: String, + target: String, + direct: bool, +} + // --------------------------------------------------------------------------- // Public entry point // --------------------------------------------------------------------------- @@ -41,6 +77,7 @@ pub async fn handle_client( domain_policy: Arc, metrics: &Metrics, access_log: bool, + request_timeout: Option, ) { let peer = client .peer_addr() @@ -49,13 +86,84 @@ pub async fn handle_client( debug!(peer = %peer, "new client connection"); loop { - match read_headers(&mut client).await { + let read_result = if let Some(limit) = request_timeout { + match timeout(limit, read_headers(&mut client)).await { + Ok(result) => result, + Err(_) => { + let _ = write_status(&mut client, 408, "Request Timeout", metrics).await; + break; + } + } + } else { + read_headers(&mut client).await + }; + + match read_result { Ok(Some(buf)) => { + metrics.add_bytes_received(buf.len() as u64); metrics.inc_requests_total(); let start = Instant::now(); - match dispatch(&mut client, &buf, &pool, mode, &domain_policy, metrics, access_log, &peer, start).await { - Ok(true) => continue, // keep-alive: read next request - Ok(false) => break, + let log_ctx = request_log_context(&buf, &domain_policy); + + let dispatch_result = if let Some(limit) = request_timeout { + match timeout( + limit, 
+ dispatch(&mut client, &buf, &pool, mode, &domain_policy, metrics), + ) + .await + { + Ok(result) => result, + Err(_) => { + let outcome = ProxyResponse::failure(504); + let _ = write_status( + &mut client, + outcome.response_status, + "Gateway Timeout", + metrics, + ) + .await; + metrics.inc_requests_failed(); + if access_log { + info!( + peer = %peer, + method = %log_ctx.method, + target = %log_ctx.target, + status = outcome.response_status, + elapsed_ms = start.elapsed().as_millis(), + direct = log_ctx.direct, + "access" + ); + } + break; + } + } + } else { + dispatch(&mut client, &buf, &pool, mode, &domain_policy, metrics).await + }; + + match dispatch_result { + Ok(response) => { + if response.success { + metrics.inc_requests_success(); + } else { + metrics.inc_requests_failed(); + } + if access_log { + info!( + peer = %peer, + method = %log_ctx.method, + target = %log_ctx.target, + status = response.response_status, + elapsed_ms = start.elapsed().as_millis(), + direct = log_ctx.direct, + "access" + ); + } + if response.keep_alive { + continue; + } + break; + } Err(e) => { metrics.inc_requests_failed(); debug!(peer = %peer, error = %e, "request dispatch error"); @@ -87,23 +195,36 @@ async fn dispatch( mode: BalanceMode, domain_policy: &Arc, metrics: &Metrics, - access_log: bool, - peer: &str, - start: Instant, -) -> Result { +) -> Result { // --- parse request line + headers --- let mut raw_headers = [httparse::EMPTY_HEADER; 96]; let mut req = httparse::Request::new(&mut raw_headers); - let body_offset = match req.parse(buf)? 
{ - httparse::Status::Complete(n) => n, - httparse::Status::Partial => bail!("incomplete request headers"), + let body_offset = match req.parse(buf) { + Ok(httparse::Status::Complete(n)) => n, + Ok(httparse::Status::Partial) => { + write_status(client, 400, "Bad Request", metrics).await?; + return Ok(ProxyResponse::failure(400)); + } + Err(_) => { + write_status(client, 400, "Bad Request", metrics).await?; + return Ok(ProxyResponse::failure(400)); + } }; - let method = req - .method - .ok_or_else(|| anyhow!("missing method"))? - .to_string(); - let path = req.path.ok_or_else(|| anyhow!("missing path"))?.to_string(); + let method = match req.method { + Some(method) => method.to_string(), + None => { + write_status(client, 400, "Bad Request", metrics).await?; + return Ok(ProxyResponse::failure(400)); + } + }; + let path = match req.path { + Some(path) => path.to_string(), + None => { + write_status(client, 400, "Bad Request", metrics).await?; + return Ok(ProxyResponse::failure(400)); + } + }; let version = req.version.unwrap_or(1); let headers = req.headers; @@ -120,41 +241,34 @@ async fn dispatch( // --- CONNECT (HTTPS tunnel) --- if method.eq_ignore_ascii_case("CONNECT") { metrics.inc_connect(); - let (host, port) = parse_connect_target(&path)?; + let (host, port) = match parse_connect_target(&path) { + Ok(target) => target, + Err(_) => { + write_status(client, 400, "Bad Request", metrics).await?; + return Ok(ProxyResponse::failure(400)); + } + }; let use_proxy = should_use_proxy(&host, domain_policy); if !use_proxy { metrics.inc_direct(); } - let result = handle_connect(client, &host, port, pool, mode, use_proxy).await; - let elapsed_ms = start.elapsed().as_millis(); - let status = if result.is_ok() { 200 } else { 502 }; - if result.is_ok() { - metrics.inc_requests_success(); - } else { - metrics.inc_requests_failed(); - } - if access_log { - info!( - peer = %peer, - method = "CONNECT", - target = %path, - status = status, - elapsed_ms = elapsed_ms, - direct = 
!use_proxy, - "access" - ); - } - return result.map(|_| false); + return handle_connect(client, &host, port, pool, mode, use_proxy, metrics).await; } // --- Plain HTTP --- metrics.inc_http(); - let (target_host, target_port) = parse_http_target(&path, headers)?; + let (target_host, target_port) = match parse_http_target(&path, headers) { + Ok(target) => target, + Err(_) => { + write_status(client, 400, "Bad Request", metrics).await?; + return Ok(ProxyResponse::failure(400)); + } + }; let use_proxy = should_use_proxy(&target_host, domain_policy); if !use_proxy { metrics.inc_direct(); } - let result = handle_http( + handle_http( client, &method, &path, @@ -167,27 +281,9 @@ async fn dispatch( mode, use_proxy, (&target_host, target_port), + metrics, ) - .await; - let elapsed_ms = start.elapsed().as_millis(); - let status = if result.is_ok() { 200 } else { 502 }; - if result.is_ok() { - metrics.inc_requests_success(); - } else { - metrics.inc_requests_failed(); - } - if access_log { - info!( - peer = %peer, - method = %method, - target = %path, - status = status, - elapsed_ms = elapsed_ms, - direct = !use_proxy, - "access" - ); - } - result + .await } // --------------------------------------------------------------------------- @@ -201,14 +297,24 @@ async fn handle_connect( pool: &Arc, mode: BalanceMode, use_proxy: bool, -) -> Result<()> { + metrics: &Metrics, +) -> Result { if !use_proxy { let addr = format!("{host}:{port}"); - let up_stream = connect_target(&addr).await?; - client - .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n") - .await?; - return tunnel(client, up_stream).await; + let up_stream = match connect_target(&addr).await { + Ok(stream) => stream, + Err(_) => { + write_status(client, 502, "Bad Gateway", metrics).await?; + return Ok(ProxyResponse::failure(502)); + } + }; + let response = b"HTTP/1.1 200 Connection Established\r\n\r\n"; + client.write_all(response).await?; + metrics.add_bytes_sent(response.len() as u64); + if let Err(e) = 
tunnel(client, up_stream, metrics).await { + debug!(error = %e, target = %addr, "CONNECT tunnel closed with error"); + } + return Ok(ProxyResponse::success(200, false)); } let mut tried: Vec = Vec::new(); @@ -218,8 +324,8 @@ async fn handle_connect( let (idx, upstream) = match pool.select(mode, &tried) { Some(u) => u, None => { - let _ = write_status(client, 502, "No upstream available").await; - bail!("no upstream available for CONNECT"); + write_status(client, 502, "No upstream available", metrics).await?; + return Ok(ProxyResponse::failure(502)); } }; tried.push(idx); @@ -235,8 +341,8 @@ async fn handle_connect( upstream.mark_offline(); upstream.record_failure(); if tried.len() >= max_retries { - let _ = write_status(client, 502, "Bad Gateway").await; - bail!("all retries exhausted for CONNECT"); + write_status(client, 502, "Bad Gateway", metrics).await?; + return Ok(ProxyResponse::failure(502)); } continue; } @@ -249,8 +355,8 @@ async fn handle_connect( upstream.mark_offline(); upstream.record_failure(); if tried.len() >= max_retries { - let _ = write_status(client, 502, "Bad Gateway").await; - bail!("all retries exhausted for CONNECT write"); + write_status(client, 502, "Bad Gateway", metrics).await?; + return Ok(ProxyResponse::failure(502)); } continue; } @@ -259,33 +365,37 @@ async fn handle_connect( match read_connect_response(&mut up_stream).await { Ok(200) => { // Success — tell client we're connected - client - .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n") - .await?; + let response = b"HTTP/1.1 200 Connection Established\r\n\r\n"; + client.write_all(response).await?; + metrics.add_bytes_sent(response.len() as u64); upstream .active_conns .fetch_add(1, std::sync::atomic::Ordering::Relaxed); let t0 = Instant::now(); - let result = tunnel(client, up_stream).await; + let result = tunnel(client, up_stream, metrics).await; upstream .active_conns .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); upstream.record_success(t0.elapsed().as_millis() 
as u64); - return result; + if let Err(e) = result { + debug!(upstream = %upstream.config.url, error = %e, "CONNECT tunnel closed with error"); + } + return Ok(ProxyResponse::success(200, false)); } Ok(status) => { // Upstream rejected CONNECT (e.g. auth required, forbidden) debug!(upstream = %upstream.config.url, status, "CONNECT rejected by upstream"); - let _ = write_status(client, status as u16, "Upstream rejected CONNECT").await; - bail!("upstream rejected CONNECT with status {status}"); + let client_status = status as u16; + write_status(client, client_status, "Upstream rejected CONNECT", metrics).await?; + return Ok(ProxyResponse::failure(client_status)); } Err(e) => { warn!(upstream = %upstream.config.url, error = %e, "CONNECT: bad upstream response"); upstream.mark_offline(); upstream.record_failure(); if tried.len() >= max_retries { - let _ = write_status(client, 502, "Bad Gateway").await; - bail!("all retries exhausted reading CONNECT response"); + write_status(client, 502, "Bad Gateway", metrics).await?; + return Ok(ProxyResponse::failure(502)); } continue; } @@ -294,13 +404,10 @@ async fn handle_connect( } /// Bidirectional copy between client and upstream until either side closes. -async fn tunnel(client: &mut TcpStream, mut upstream: TcpStream) -> Result<()> { - let (mut cr, mut cw) = tokio::io::split(client); - let (mut ur, mut uw) = tokio::io::split(&mut upstream); - tokio::select! 
{ - r = tokio::io::copy(&mut cr, &mut uw) => { r?; } - r = tokio::io::copy(&mut ur, &mut cw) => { r?; } - } +async fn tunnel(client: &mut TcpStream, mut upstream: TcpStream, metrics: &Metrics) -> Result<()> { + let (from_client, to_client) = tokio::io::copy_bidirectional(client, &mut upstream).await?; + metrics.add_bytes_received(from_client); + metrics.add_bytes_sent(to_client); Ok(()) } @@ -322,14 +429,15 @@ async fn handle_http( mode: BalanceMode, use_proxy: bool, direct_target: (&str, u16), -) -> Result { + metrics: &Metrics, +) -> Result { if !use_proxy { let addr = format!("{}:{}", direct_target.0, direct_target.1); let mut direct_stream = match connect_target(&addr).await { Ok(s) => s, Err(_) => { - let _ = write_status(client, 502, "Bad Gateway").await; - return Ok(false); + write_status(client, 502, "Bad Gateway", metrics).await?; + return Ok(ProxyResponse::failure(502)); } }; let no_proxy_auth = None; @@ -348,18 +456,22 @@ async fn handle_http( req_content_length, req_is_chunked, already_in_buf as u64, + metrics, ) .await .is_err() { - return Ok(false); + return Ok(ProxyResponse::failure(502)); } - let upstream_keep_alive = match forward_response(&mut direct_stream, client, method).await { - Ok(v) => v, - Err(_) => return Ok(false), + let response = match forward_response(&mut direct_stream, client, method, metrics).await { + Ok(outcome) => outcome, + Err(_) => return Ok(ProxyResponse::failure(502)), }; - return Ok(client_keep_alive && upstream_keep_alive); + return Ok(ProxyResponse::success( + response.response_status, + client_keep_alive && response.upstream_keep_alive, + )); } let mut tried: Vec = Vec::new(); @@ -369,8 +481,8 @@ async fn handle_http( let (idx, upstream) = match pool.select(mode, &tried) { Some(u) => u, None => { - let _ = write_status(client, 502, "No upstream available").await; - return Ok(false); + write_status(client, 502, "No upstream available", metrics).await?; + return Ok(ProxyResponse::failure(502)); } }; tried.push(idx); @@ 
-386,8 +498,8 @@ async fn handle_http( upstream.mark_offline(); upstream.record_failure(); if tried.len() >= max_retries { - let _ = write_status(client, 502, "Bad Gateway").await; - return Ok(false); + write_status(client, 502, "Bad Gateway", metrics).await?; + return Ok(ProxyResponse::failure(502)); } continue; } @@ -405,8 +517,8 @@ async fn handle_http( upstream.mark_offline(); upstream.record_failure(); if tried.len() >= max_retries { - let _ = write_status(client, 502, "Bad Gateway").await; - return Ok(false); + write_status(client, 502, "Bad Gateway", metrics).await?; + return Ok(ProxyResponse::failure(502)); } continue; } @@ -420,7 +532,8 @@ async fn handle_http( upstream.mark_offline(); upstream.record_failure(); if tried.len() >= max_retries { - return Ok(false); + write_status(client, 502, "Bad Gateway", metrics).await?; + return Ok(ProxyResponse::failure(502)); } continue; } @@ -431,11 +544,12 @@ async fn handle_http( req_content_length, req_is_chunked, already_in_buf as u64, + metrics, ) .await { warn!(error = %e, "HTTP: body forward error"); - return Ok(false); + return Ok(ProxyResponse::failure(502)); } // --- Stream response back --- @@ -443,22 +557,25 @@ async fn handle_http( upstream .active_conns .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let resp_result = forward_response(&mut up_stream, client, method).await; + let resp_result = forward_response(&mut up_stream, client, method, metrics).await; upstream .active_conns .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); match resp_result { - Ok(upstream_keep_alive) => { + Ok(response) => { upstream.record_success(t0.elapsed().as_millis() as u64); - return Ok(client_keep_alive && upstream_keep_alive); + return Ok(ProxyResponse::success( + response.response_status, + client_keep_alive && response.upstream_keep_alive, + )); } Err(e) => { warn!(upstream = %upstream.config.url, error = %e, "HTTP: response forward failed"); upstream.mark_offline(); upstream.record_failure(); if tried.len() >= 
max_retries { - return Ok(false); + return Ok(ProxyResponse::failure(502)); } continue; } @@ -482,14 +599,17 @@ async fn forward_body( content_length: Option, is_chunked: bool, already_sent: u64, + metrics: &Metrics, ) -> Result<()> { if let Some(len) = content_length { let remaining = len.saturating_sub(already_sent); if remaining > 0 { - copy_exact(client, upstream, remaining).await?; + let copied = copy_exact(client, upstream, remaining).await?; + metrics.add_bytes_received(copied); } } else if is_chunked { - copy_chunked(client, upstream).await?; + let copied = copy_chunked(client, upstream).await?; + metrics.add_bytes_received(copied); } // else: no body (GET / HEAD / etc.) Ok(()) @@ -504,7 +624,8 @@ async fn forward_response( upstream: &mut TcpStream, client: &mut TcpStream, req_method: &str, -) -> Result { + metrics: &Metrics, +) -> Result { // Read response headers let header_buf = read_headers(upstream) .await? @@ -527,14 +648,17 @@ async fn forward_response( .map(|v| v.to_ascii_lowercase().contains("chunked")) .unwrap_or(false); let upstream_keep_alive = upstream_wants_keep_alive(headers, version); + let response_status = status; // Forward headers verbatim client.write_all(&header_buf[..body_offset]).await?; + metrics.add_bytes_sent(body_offset as u64); // Forward any body bytes that were buffered with the headers let already_in_buf = header_buf.len().saturating_sub(body_offset); if already_in_buf > 0 { client.write_all(&header_buf[body_offset..]).await?; + metrics.add_bytes_sent(already_in_buf as u64); } // Determine whether this response has a body @@ -547,19 +671,28 @@ async fn forward_response( if let Some(len) = resp_content_length { let remaining = len.saturating_sub(already_in_buf as u64); if remaining > 0 { - copy_exact(upstream, client, remaining).await?; + let copied = copy_exact(upstream, client, remaining).await?; + metrics.add_bytes_sent(copied); } } else if resp_is_chunked { - copy_chunked(upstream, client).await?; + let copied = 
copy_chunked(upstream, client).await?; + metrics.add_bytes_sent(copied); } else { // No Content-Length and not chunked: read until upstream closes. // We must also close the client connection afterwards. - tokio::io::copy(upstream, client).await?; - return Ok(false); + let copied = tokio::io::copy(upstream, client).await?; + metrics.add_bytes_sent(copied); + return Ok(ResponseForwardOutcome { + response_status, + upstream_keep_alive: false, + }); } } - Ok(upstream_keep_alive) + Ok(ResponseForwardOutcome { + response_status, + upstream_keep_alive, + }) } // --------------------------------------------------------------------------- @@ -568,17 +701,19 @@ async fn forward_response( /// Read a raw chunked stream from `src` and write to `dst` until the /// terminal `0\r\n\r\n` chunk. -async fn copy_chunked(src: &mut TcpStream, dst: &mut TcpStream) -> Result<()> { +async fn copy_chunked(src: &mut TcpStream, dst: &mut TcpStream) -> Result { let mut buf = vec![0u8; 8 * 1024]; // We look for the terminal chunk marker in what we forward. // Since this is a relay, we forward bytes verbatim and detect the end. let mut trailer = Vec::new(); + let mut copied = 0; loop { let n = src.read(&mut buf).await?; if n == 0 { break; } dst.write_all(&buf[..n]).await?; + copied += n as u64; // Accumulate last 8 bytes to detect "0\r\n\r\n" trailer.extend_from_slice(&buf[..n]); if trailer.len() > 8 { @@ -588,12 +723,13 @@ async fn copy_chunked(src: &mut TcpStream, dst: &mut TcpStream) -> Result<()> { break; } } - Ok(()) + Ok(copied) } /// Copy exactly `bytes` bytes from `src` to `dst`. 
-async fn copy_exact(src: &mut TcpStream, dst: &mut TcpStream, mut bytes: u64) -> Result<()> { +async fn copy_exact(src: &mut TcpStream, dst: &mut TcpStream, mut bytes: u64) -> Result { let mut buf = vec![0u8; 8 * 1024]; + let mut copied = 0; while bytes > 0 { let to_read = (buf.len() as u64).min(bytes) as usize; let n = src.read(&mut buf[..to_read]).await?; @@ -601,9 +737,10 @@ async fn copy_exact(src: &mut TcpStream, dst: &mut TcpStream, mut bytes: u64) -> bail!("unexpected EOF: expected {bytes} more bytes"); } dst.write_all(&buf[..n]).await?; + copied += n as u64; bytes -= n as u64; } - Ok(()) + Ok(copied) } /// Read bytes from `stream` into a growing buffer until the HTTP header @@ -908,12 +1045,61 @@ fn upstream_wants_keep_alive(headers: &[httparse::Header<'_>], version: u8) -> b } } -async fn write_status(stream: &mut TcpStream, code: u16, msg: &str) -> Result<()> { - let resp = format!("HTTP/1.1 {code} {msg}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"); +pub async fn reject_connection(mut stream: TcpStream) -> Result<()> { + let resp = build_status_response(503, "Service Unavailable"); stream.write_all(resp.as_bytes()).await?; Ok(()) } +async fn write_status( + stream: &mut TcpStream, + code: u16, + msg: &str, + metrics: &Metrics, +) -> Result<()> { + let resp = build_status_response(code, msg); + metrics.add_bytes_sent(resp.len() as u64); + stream.write_all(resp.as_bytes()).await?; + Ok(()) +} + +fn build_status_response(code: u16, msg: &str) -> String { + format!("HTTP/1.1 {code} {msg}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") +} + +fn request_log_context(buf: &[u8], domain_policy: &DomainPolicyConfig) -> RequestLogContext { + let mut ctx = RequestLogContext { + method: "UNKNOWN".to_string(), + target: "".to_string(), + direct: false, + }; + + let mut raw_headers = [httparse::EMPTY_HEADER; 96]; + let mut req = httparse::Request::new(&mut raw_headers); + let parsed = matches!(req.parse(buf), Ok(httparse::Status::Complete(_))); + if !parsed 
{ + return ctx; + } + + if let Some(method) = req.method { + ctx.method = method.to_string(); + } + if let Some(path) = req.path { + ctx.target = path.to_string(); + ctx.direct = if ctx.method.eq_ignore_ascii_case("CONNECT") { + parse_connect_target(path) + .map(|(host, _)| !should_use_proxy(&host, domain_policy)) + .unwrap_or(false) + } else { + parse_http_target(path, req.headers) + .map(|(host, _)| !should_use_proxy(&host, domain_policy)) + .unwrap_or(false) + }; + } + + ctx +} + #[cfg(test)] mod tests { use super::*; From 16fe02754eae69a4ff2b46064f514acced086a48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=9F=E9=9B=B2=E5=B8=8C=E5=87=AA?= <70424266+MejiroRina@users.noreply.github.com> Date: Mon, 23 Mar 2026 06:46:17 +0800 Subject: [PATCH 11/13] [Chore] Update AGENTS.md and copilot-instructions.md --- .github/copilot-Instructions.md | 27 ------------------ .github/copilot-instructions.md | 49 +++++++++++++++++++++++++++++++++ AGENTS.md | 21 ++++++++++++-- 3 files changed, 68 insertions(+), 29 deletions(-) delete mode 100644 .github/copilot-Instructions.md create mode 100644 .github/copilot-instructions.md diff --git a/.github/copilot-Instructions.md b/.github/copilot-Instructions.md deleted file mode 100644 index 60e6e26..0000000 --- a/.github/copilot-Instructions.md +++ /dev/null @@ -1,27 +0,0 @@ -# copilot-Instructions.md - -## Project coding notes - -- Language: Rust (edition 2021) -- Runtime: Tokio -- Proxy protocol handling is in `src/proxy.rs` -- Configuration is YAML via `yaml_serde` - -## Development expectations - -1. Make surgical changes that directly address the request. -2. Add/adjust tests for changed behavior. -3. Keep docs aligned with user-facing configuration changes. -4. 
Validate with: - - `cargo test` - - `cargo clippy -- -D warnings` - -## Domain policy expressions - -`domain_policy.domains` currently supports: - -- `domain:example.com` (exact match) -- `suffix:example.com` (domain suffix match) -- `*.example.com` (suffix shorthand) -- `.example.com` (suffix shorthand) -- `example.com` (backward-compatible exact-or-suffix behavior) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000..d84442f --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,49 @@ +# Copilot Instructions + +## Project coding notes + +- Language: Rust (edition 2021) +- Runtime: Tokio +- Proxy protocol handling is in `src/proxy.rs` +- Configuration is YAML via `yaml_serde` +- Admin endpoints and shared metrics are in `src/admin.rs` +- Integration coverage lives in `tests/integration_smoke.rs` +- Container smoke coverage lives in `scripts/docker-smoke.sh` + +## Project status + +- This project is in a completed-v1 / release-ready state. +- Default to preserving current runtime behavior unless a task explicitly asks for a change. +- Regressions in proxy forwarding, response codes, admin APIs, metrics, graceful shutdown, request timeouts, connection limiting, and Docker deployment should be treated as important. + +## Development expectations + +1. Make surgical changes that directly address the request. +2. Add/adjust tests for changed behavior. +3. Keep docs aligned with user-facing configuration changes. +4. 
Validate with: + - `cargo test` + - `cargo clippy -- -D warnings` + - `cargo fmt -- --check` + - `./scripts/docker-smoke.sh` when Docker-related or deployment behavior changes + - `./target/debug/http-proxy-lb --config config.yaml --check` when config behavior changes + +## Validation focus + +- Prefer integration coverage for: + - config validation + - real HTTP status propagation + - request timeout behavior + - connection limiting behavior + - admin metrics/status counters +- Keep README, `AGENTS.md`, and this file aligned when the project’s validation workflow changes. + +## Domain policy expressions + +`domain_policy.domains` currently supports: + +- `domain:example.com` (exact match) +- `suffix:example.com` (domain suffix match) +- `*.example.com` (suffix shorthand) +- `.example.com` (suffix shorthand) +- `example.com` (backward-compatible exact-or-suffix behavior) diff --git a/AGENTS.md b/AGENTS.md index a89c309..7901d61 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -2,7 +2,13 @@ ## Purpose -This repository contains `http-proxy-lb`, a Rust HTTP relay proxy with upstream load balancing, health checks, hot reload, and domain policy routing. +This repository contains `http-proxy-lb`, a Rust HTTP relay proxy with upstream load balancing, health checks, hot reload, domain policy routing, admin endpoints, Prometheus metrics, graceful shutdown, request limits, and container deployment assets. + +## Status + +- The repository is in a near-release / completed-v1 state. +- Prefer preserving the current behavior and validation baseline unless the task explicitly requires a change. +- Treat regressions in proxy behavior, admin endpoints, metrics, timeout handling, connection limiting, or Docker deployment as high priority. 
## Scope @@ -14,12 +20,23 @@ This repository contains `http-proxy-lb`, a Rust HTTP relay proxy with upstream - Run tests: `cargo test` - Run lint checks: `cargo clippy -- -D warnings` -Format code: `cargo fmt` +Check formatting: `cargo fmt -- --check` +- Run container smoke test: `./scripts/docker-smoke.sh` +- Validate config file: `./target/debug/http-proxy-lb --config config.yaml --check` ## Code structure - `src/config.rs`: configuration model and YAML loading +- `src/admin.rs`: admin server, `/metrics`, `/status`, `/health`, and shared metrics - `src/upstream.rs`: upstream entry state + pool selection/reload - `src/health.rs`: active health checking logic - `src/proxy.rs`: CONNECT and HTTP forwarding logic - `src/main.rs`: startup, accept loop, background tasks +- `tests/integration_smoke.rs`: end-to-end integration tests for config validation, timeouts, metrics, and connection limits +- `scripts/docker-smoke.sh`: local Docker smoke test helper + +## Expectations + +- Keep changes minimal and targeted. +- Update tests and docs for any user-visible behavior change. +- Prefer keeping `cargo test`, `cargo clippy -- -D warnings`, `cargo fmt -- --check`, and `./scripts/docker-smoke.sh` green before considering work complete. 
From 6a93533455fc6cd85a1c60ab72770ec78a213e87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=9F=E9=9B=B2=E5=B8=8C=E5=87=AA?= <70424266+MejiroRina@users.noreply.github.com> Date: Mon, 23 Mar 2026 06:59:29 +0800 Subject: [PATCH 12/13] [Chore] Add CI --- .github/workflows/ci.yml | 88 ++++++++++++++++++++++ .github/workflows/release.yml | 137 ++++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..8349f5d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,88 @@ +name: CI + +on: + push: + pull_request: + +permissions: + contents: read + +env: + CARGO_TERM_COLOR: always + +jobs: + validate: + name: Validate + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + components: clippy, rustfmt + + - name: Cache Rust artifacts + uses: Swatinem/rust-cache@v2 + + - name: Format check + run: cargo fmt -- --check + + - name: Clippy + run: cargo clippy --all-targets -- -D warnings + + - name: Test + run: cargo test --locked + + binary-build: + name: Binary Build (${{ matrix.os }}) + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-latest + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache Rust artifacts + uses: Swatinem/rust-cache@v2 + + - name: Build release binary + run: cargo build --release --locked + + - name: Upload binary artifact (Unix) + if: runner.os != 'Windows' + uses: actions/upload-artifact@v6 + with: + name: http-proxy-lb-${{ runner.os }} + path: target/release/http-proxy-lb + + - name: Upload binary artifact (Windows) + if: runner.os == 'Windows' + uses: actions/upload-artifact@v6 + with: + name: 
http-proxy-lb-${{ runner.os }} + path: target/release/http-proxy-lb.exe + + docker-build: + name: Docker Build + runs-on: ubuntu-latest + needs: validate + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Run Docker smoke test + run: ./scripts/docker-smoke.sh diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..376de82 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,137 @@ +name: Release + +on: + push: + tags: + - "v*" + +permissions: + contents: write + packages: write + +env: + CARGO_TERM_COLOR: always + BINARY_NAME: http-proxy-lb + +jobs: + release-binaries: + name: Release Binary (${{ matrix.name }}) + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + include: + - name: linux-x86_64 + os: ubuntu-latest + target: x86_64-unknown-linux-gnu + archive_ext: tar.gz + - name: macos-aarch64 + os: macos-14 + target: aarch64-apple-darwin + archive_ext: tar.gz + - name: windows-x86_64 + os: windows-latest + target: x86_64-pc-windows-msvc + archive_ext: zip + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + with: + targets: ${{ matrix.target }} + + - name: Cache Rust artifacts + uses: Swatinem/rust-cache@v2 + + - name: Build release binary + run: cargo build --release --locked --target ${{ matrix.target }} + + - name: Package archive (Unix) + if: runner.os != 'Windows' + shell: bash + run: | + set -euo pipefail + pkg_dir="${BINARY_NAME}-${{ github.ref_name }}-${{ matrix.target }}" + mkdir -p "${pkg_dir}" + cp "target/${{ matrix.target }}/release/${BINARY_NAME}" "${pkg_dir}/" + cp README.md LICENSE "${pkg_dir}/" + tar -czf "${pkg_dir}.tar.gz" "${pkg_dir}" + + - name: Package archive (Windows) + if: runner.os == 'Windows' + shell: pwsh + run: | + $pkgDir = "${env:BINARY_NAME}-${{ github.ref_name }}-${{ matrix.target }}" + New-Item -ItemType Directory -Path $pkgDir | Out-Null + Copy-Item 
"target/${{ matrix.target }}/release/${env:BINARY_NAME}.exe" "$pkgDir/" + Copy-Item "README.md","LICENSE" "$pkgDir/" + Compress-Archive -Path "$pkgDir/*" -DestinationPath "$pkgDir.zip" + + - name: Upload packaged artifact + uses: actions/upload-artifact@v6 + with: + name: release-${{ matrix.target }} + path: | + http-proxy-lb-${{ github.ref_name }}-${{ matrix.target }}.tar.gz + http-proxy-lb-${{ github.ref_name }}-${{ matrix.target }}.zip + if-no-files-found: ignore + + github-release: + name: GitHub Release + runs-on: ubuntu-latest + needs: release-binaries + + steps: + - name: Download packaged artifacts + uses: actions/download-artifact@v5 + with: + path: dist + pattern: release-* + merge-multiple: true + + - name: Publish GitHub Release + uses: softprops/action-gh-release@v2 + with: + files: | + dist/*.tar.gz + dist/*.zip + generate_release_notes: true + + docker-release: + name: Docker Release + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract Docker metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ghcr.io/${{ github.repository }} + tags: | + type=ref,event=tag + + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: . 
+ file: ./Dockerfile + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} From 1da24b2b02fd29bf96885f88440a3d88ecd1235f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=9F=E9=9B=B2=E5=B8=8C=E5=87=AA?= <70424266+MejiroRina@users.noreply.github.com> Date: Mon, 23 Mar 2026 07:01:28 +0800 Subject: [PATCH 13/13] [Chore] Add Docker smoke script --- scripts/docker-smoke.sh | 54 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100755 scripts/docker-smoke.sh diff --git a/scripts/docker-smoke.sh b/scripts/docker-smoke.sh new file mode 100755 index 0000000..7c1d5aa --- /dev/null +++ b/scripts/docker-smoke.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +TMP_DIR="$(mktemp -d)" +IMAGE_TAG="http-proxy-lb:smoke" +CONTAINER_NAME="http-proxy-lb-smoke" + +cleanup() { + docker rm -f "${CONTAINER_NAME}" >/dev/null 2>&1 || true + rm -rf "${TMP_DIR}" +} + +trap cleanup EXIT + +cat >"${TMP_DIR}/config.yaml" <<'EOF' +listen: "0.0.0.0:8080" +admin_listen: "0.0.0.0:9090" +mode: round_robin +reload_interval_secs: 0 +access_log: true +health_check: + interval_secs: 30 + timeout_secs: 2 +limits: + max_connections: 32 + request_timeout_secs: 5 + shutdown_timeout_secs: 5 +domain_policy: + mode: whitelist + domains: [] +upstream: [] +EOF + +docker build -t "${IMAGE_TAG}" "${ROOT_DIR}" + +docker run -d \ + --name "${CONTAINER_NAME}" \ + -p 18080:8080 \ + -p 19090:9090 \ + -v "${TMP_DIR}/config.yaml:/etc/http-proxy-lb/config.yaml:ro" \ + "${IMAGE_TAG}" >/dev/null + +for _ in $(seq 1 20); do + if curl -fsS "http://127.0.0.1:19090/health" >/dev/null 2>&1; then + break + fi + sleep 1 +done + +curl -fsS "http://127.0.0.1:19090/health" +echo +curl -fsS "http://127.0.0.1:19090/status" +echo