Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 227 additions & 14 deletions crates/coven-cli/src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,7 @@ pub fn serve_next_tcp_connection(
status,
runtime,
Some(MAX_TCP_BODY_BYTES),
true,
)
}

Expand Down Expand Up @@ -1034,20 +1035,35 @@ fn handle_http_stream<R, W>(
status: Option<DaemonStatus>,
runtime: &dyn SessionRuntime,
max_body_bytes: Option<usize>,
enforce_loopback_guard: bool,
) -> Result<()>
where
R: Read,
W: Write,
{
let mut reader = BufReader::new(read);
let request_line = read_http_request_line(&mut reader)?;
let content_length = read_http_headers(&mut reader)?;
let headers = read_http_headers(&mut reader)?;
// On the TCP transport (loopback-only), defend against browser-driven CSRF
// and DNS-rebinding: a real CLI/proxy client never sends a cross-origin
// Origin, and a rebinding attack arrives with a non-loopback Host. The Unix
// socket is filesystem-gated and skips this.
if enforce_loopback_guard {
if !host_is_loopback(headers.host.as_deref()) {
return write_forbidden(&mut write, "Host header must be a loopback address.");
}
if let Some(origin) = headers.origin.as_deref() {
if !origin_is_loopback(origin) {
return write_forbidden(&mut write, "Cross-origin requests are not allowed.");
}
}
}
if let Some(max) = max_body_bytes {
if content_length > max {
if headers.content_length > max {
return write_payload_too_large(&mut write, max);
}
}
let body = read_http_body(&mut reader, content_length)?;
let body = read_http_body(&mut reader, headers.content_length)?;
let (method, path) = parse_request_line(&request_line)?;
let response = crate::api::handle_request_with_runtime(
method,
Expand Down Expand Up @@ -1088,6 +1104,59 @@ fn write_payload_too_large<W: Write>(write: &mut W, max: usize) -> Result<()> {
Ok(())
}

#[cfg(unix)]
fn host_is_loopback(host: Option<&str>) -> bool {
match host {
Some(h) => is_loopback_host(strip_port(h.trim())),
None => false,
}
}

#[cfg(unix)]
fn origin_is_loopback(origin: &str) -> bool {
match origin.trim().split_once("://") {
Some((_scheme, rest)) => is_loopback_host(strip_port(rest)),
None => false,
}
}

#[cfg(unix)]
fn strip_port(authority: &str) -> &str {
if let Some(rest) = authority.strip_prefix('[') {
// IPv6 literal like [::1]:8080 -> ::1
return rest.split(']').next().unwrap_or(rest);
}
authority.split(':').next().unwrap_or(authority)
}

#[cfg(unix)]
fn is_loopback_host(host: &str) -> bool {
// Parse as an IP and ask the address itself — never a string prefix. A prefix
// test like `starts_with("127.")` would also accept attacker hostnames such as
// `127.evil.com`, defeating the DNS-rebinding guard this function backs.
if host == "localhost" {
return true;
}
host.parse::<std::net::IpAddr>()
.map(|ip| ip.is_loopback())
.unwrap_or(false)
}

#[cfg(unix)]
fn write_forbidden<W: Write>(write: &mut W, reason: &str) -> Result<()> {
let body =
format!("{{\"ok\":false,\"error\":{{\"code\":\"forbidden\",\"message\":\"{reason}\"}}}}");
let http = format!(
"HTTP/1.1 403 Forbidden\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body
);
write
.write_all(http.as_bytes())
.context("failed to write 403 response")?;
Ok(())
}

#[cfg(unix)]
pub fn serve_next_connection(
listener: &UnixListener,
Expand All @@ -1101,7 +1170,7 @@ pub fn serve_next_connection(
let read = stream.try_clone().context("failed to clone Unix stream")?;
// Unix socket has no body cap — only local processes can reach it and the
// socket permission bits already gate access.
handle_http_stream(read, stream, coven_home, status, runtime, None)
handle_http_stream(read, stream, coven_home, status, runtime, None, false)
}

fn http_reason_phrase(status: u16) -> &'static str {
Expand Down Expand Up @@ -1130,8 +1199,19 @@ fn read_http_request_line<R: BufRead>(reader: &mut R) -> Result<String> {
}

#[cfg(unix)]
fn read_http_headers<R: BufRead>(reader: &mut R) -> Result<usize> {
let mut content_length = 0;
struct ParsedHeaders {
content_length: usize,
host: Option<String>,
origin: Option<String>,
}

#[cfg(unix)]
fn read_http_headers<R: BufRead>(reader: &mut R) -> Result<ParsedHeaders> {
let mut headers = ParsedHeaders {
content_length: 0,
host: None,
origin: None,
};
let mut header = String::new();
loop {
header.clear();
Expand All @@ -1142,15 +1222,18 @@ fn read_http_headers<R: BufRead>(reader: &mut R) -> Result<usize> {
break;
}
if let Some((name, value)) = header.split_once(':') {
let name = name.trim();
let value = value.trim();
if name.eq_ignore_ascii_case("content-length") {
content_length = value
.trim()
.parse()
.context("invalid Content-Length header")?;
headers.content_length = value.parse().context("invalid Content-Length header")?;
} else if name.eq_ignore_ascii_case("host") {
headers.host = Some(value.to_string());
} else if name.eq_ignore_ascii_case("origin") {
headers.origin = Some(value.to_string());
}
}
}
Ok(content_length)
Ok(headers)
}

#[cfg(unix)]
Expand Down Expand Up @@ -1400,8 +1483,16 @@ mod tests {
let mut stream = Cursor::new(Vec::from(&request[..]));
let mut output: Vec<u8> = Vec::new();
let runtime = NoopSessionRuntime;
handle_http_stream(&mut stream, &mut output, temp.path(), None, &runtime, None)
.expect("handle ok");
handle_http_stream(
&mut stream,
&mut output,
temp.path(),
None,
&runtime,
None,
false,
)
.expect("handle ok");
let response = String::from_utf8(output).expect("utf8");
assert!(response.starts_with("HTTP/1.1 200 OK"), "got: {response}");
assert!(response.contains("\"apiVersion\""), "got: {response}");
Expand Down Expand Up @@ -1430,6 +1521,7 @@ mod tests {
None,
&runtime,
Some(MAX_TCP_BODY_BYTES),
false,
)
.expect("handle ok");
let response = String::from_utf8(output).expect("utf8");
Expand All @@ -1440,6 +1532,125 @@ mod tests {
assert!(response.contains("payload_too_large"), "got: {response}");
}

#[cfg(unix)]
#[test]
fn handle_http_stream_guard_blocks_cross_origin() {
use crate::api::NoopSessionRuntime;
use std::io::Cursor;
let temp = tempfile::tempdir().expect("tempdir");
ensure_private_coven_home(temp.path()).expect("ensure home");
let request = b"GET /api/v1/health HTTP/1.1\r\nHost: 127.0.0.1:3000\r\nOrigin: http://evil.example\r\n\r\n";
let mut stream = Cursor::new(Vec::from(&request[..]));
let mut output: Vec<u8> = Vec::new();
handle_http_stream(
&mut stream,
&mut output,
temp.path(),
None,
&NoopSessionRuntime,
Some(MAX_TCP_BODY_BYTES),
true,
)
.expect("handle ok");
let response = String::from_utf8(output).expect("utf8");
assert!(
response.starts_with("HTTP/1.1 403 Forbidden"),
"got: {response}"
);
}

#[cfg(unix)]
#[test]
fn handle_http_stream_guard_blocks_foreign_host() {
use crate::api::NoopSessionRuntime;
use std::io::Cursor;
let temp = tempfile::tempdir().expect("tempdir");
ensure_private_coven_home(temp.path()).expect("ensure home");
let request = b"GET /api/v1/health HTTP/1.1\r\nHost: evil.example\r\n\r\n";
let mut stream = Cursor::new(Vec::from(&request[..]));
let mut output: Vec<u8> = Vec::new();
handle_http_stream(
&mut stream,
&mut output,
temp.path(),
None,
&NoopSessionRuntime,
Some(MAX_TCP_BODY_BYTES),
true,
)
.expect("handle ok");
let response = String::from_utf8(output).expect("utf8");
assert!(
response.starts_with("HTTP/1.1 403 Forbidden"),
"got: {response}"
);
}

#[cfg(unix)]
#[test]
fn is_loopback_host_accepts_only_real_loopback_addresses() {
// Real loopback: the whole 127.0.0.0/8, ::1, and the localhost name.
assert!(is_loopback_host("127.0.0.1"));
assert!(is_loopback_host("127.0.0.2"));
assert!(is_loopback_host("::1"));
assert!(is_loopback_host("localhost"));
// Hostnames that merely *start with* "127." must NOT pass: a DNS-rebinding
// attacker can register 127.evil.com -> 127.0.0.1 and would otherwise slip
// through a string-prefix check and defeat the loopback guard.
assert!(!is_loopback_host("127.evil.com"));
assert!(!is_loopback_host("127001.example.com"));
assert!(!is_loopback_host("evil.example"));
assert!(!is_loopback_host(""));
}

#[cfg(unix)]
#[test]
fn handle_http_stream_guard_allows_loopback_origin() {
use crate::api::NoopSessionRuntime;
use std::io::Cursor;
let temp = tempfile::tempdir().expect("tempdir");
ensure_private_coven_home(temp.path()).expect("ensure home");
let request = b"GET /api/v1/health HTTP/1.1\r\nHost: localhost:3000\r\nOrigin: http://localhost:3000\r\n\r\n";
let mut stream = Cursor::new(Vec::from(&request[..]));
let mut output: Vec<u8> = Vec::new();
handle_http_stream(
&mut stream,
&mut output,
temp.path(),
None,
&NoopSessionRuntime,
Some(MAX_TCP_BODY_BYTES),
true,
)
.expect("handle ok");
let response = String::from_utf8(output).expect("utf8");
assert!(response.starts_with("HTTP/1.1 200 OK"), "got: {response}");
}

#[cfg(unix)]
#[test]
fn handle_http_stream_unix_path_ignores_origin() {
use crate::api::NoopSessionRuntime;
use std::io::Cursor;
let temp = tempfile::tempdir().expect("tempdir");
ensure_private_coven_home(temp.path()).expect("ensure home");
let request = b"GET /api/v1/health HTTP/1.1\r\nHost: evil.example\r\nOrigin: http://evil.example\r\n\r\n";
let mut stream = Cursor::new(Vec::from(&request[..]));
let mut output: Vec<u8> = Vec::new();
handle_http_stream(
&mut stream,
&mut output,
temp.path(),
None,
&NoopSessionRuntime,
None,
false,
)
.expect("handle ok");
let response = String::from_utf8(output).expect("utf8");
assert!(response.starts_with("HTTP/1.1 200 OK"), "got: {response}");
}

#[cfg(unix)]
#[test]
fn bind_tcp_listener_serves_health_over_tcp() {
Expand All @@ -1462,7 +1673,9 @@ mod tests {
.set_read_timeout(Some(std::time::Duration::from_secs(5)))
.expect("read timeout");
client
.write_all(b"GET /api/v1/health HTTP/1.1\r\nHost: x\r\nContent-Length: 0\r\n\r\n")
.write_all(
b"GET /api/v1/health HTTP/1.1\r\nHost: 127.0.0.1\r\nContent-Length: 0\r\n\r\n",
)
.expect("write request");
let mut response = String::new();
client.read_to_string(&mut response).expect("read response");
Expand Down