Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,20 @@ linters:
deny:
- pkg: github.com/MustardSeedNetworks/stem/internal/api
desc: "tlsutil is a leaf — cert provisioning, ACME, and TLS config take no api types, so it never imports the api transport layer"
# cors is a leaf of internal/api (ADR-0011): the origin-classification
# logic (localhost / same-origin / RFC1918) depends only on stdlib
# (net/url, strings) — never on the api transport layer itself. This
# depguard rule enforces that boundary so a future accidental upward
# import fails CI rather than silently coupling the leaf back into the
# transport layer. Test files are excluded: the _test.go external
# package imports the leaf itself, which is expected.
"api-cors-isolated":
files:
- "**/internal/api/cors/**"
- "!$test"
deny:
- pkg: github.com/MustardSeedNetworks/stem/internal/api
desc: "cors is a leaf — origin classification takes plain strings, so it never imports the api transport layer; wire the middleware at the api layer"

embeddedstructfieldcheck:
# Checks that sync.Mutex and sync.RWMutex are not used as embedded fields.
Expand Down
20 changes: 17 additions & 3 deletions docs/adr/0011-internal-api-sub-package-decomposition.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ The api transport layer wires leaves at construction time; no leaf knows about
|-------|---------|-----|
| Rate limiter | `internal/api/ratelimit` | #451 |
| SSE broadcaster | `internal/api/sse` | #452 |
| TLS utilities | `internal/api/tlsutil` | this ADR |
| TLS utilities | `internal/api/tlsutil` | #453 |
| CORS origin classification | `internal/api/cors` | this ADR |

### SSE slice (this ADR)

Expand Down Expand Up @@ -71,16 +72,29 @@ TLS were rehomed to the api layer rather than dragged into the leaf:
`acmeReadHeaderTimeoutSec` (transport-layer challenge-server timeout →
`server.go`).

### CORS slice (this ADR)

`internal/api/cors` holds the Origin-header classification used by the CORS
policy: `IsLocalhostOrigin`, `IsSameOrigin`, and `IsRFC1918Origin` (with the
strict, complete-IP-structure validation that rejects bypass tricks like
`localhost.evil.com` and `192.168.1.1.evil.com`). It depends only on stdlib
(`net/url`, `strings`).

The HTTP middleware that consumes the classifiers (`corsMiddleware`), the opt-in
env read (`corsAllowPrivateEnabled`, which logs the credentialed-LAN-origin
warning), and the response-header wiring stay in `internal/api`.

### Future slices (candidates)

| Concern | Notes |
|---------|-------|
| CORS logic | RFC 1918 origin validation |
| Handler grouping | the 15 `handlers_*.go` are Server-coupled transport code, not stdlib leaves — needs a naming/grouping decision rather than the leaf recipe |

## Consequences

- The leaf boundary is statically enforced by depguard (`api-sse-isolated`,
`api-ratelimit-isolated`, `api-tlsutil-isolated` rules in `.golangci.yml`).
`api-ratelimit-isolated`, `api-tlsutil-isolated`, `api-cors-isolated` rules in
`.golangci.yml`).
- `go vet` + `golangci-lint` catch upward imports at CI time.
- `internal/api` package size decreases incrementally with each slice.
- No behaviour change: endpoints, event types, and publish sites are identical.
203 changes: 203 additions & 0 deletions internal/api/cors/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// SPDX-License-Identifier: BUSL-1.1

// Package cors classifies HTTP Origin header values for the api transport
// layer's CORS policy: localhost detection, same-origin matching, and RFC 1918
// private-network validation. The classification is deliberately strict —
// validating the complete IP structure — to prevent CORS-bypass tricks such as
// "localhost.evil.com" or "192.168.1.1.evil.com".
//
// It is a leaf of internal/api (ADR-0011): it depends only on the standard
// library (net/url, strings) — never on the api transport layer itself. The
// boundary is enforced by depguard (api-cors-isolated). The HTTP middleware
// that consumes these classifiers (corsMiddleware), the opt-in env read
// (corsAllowPrivateEnabled), and the response-header wiring stay in the api
// transport layer.
package cors

import (
"net/url"
"strings"
)

// RFC 1918 validation constants.
const (
// ipPartsClassC is the expected number of IP parts for Class C address validation.
ipPartsClassC = 2

// ipPartsClassAB is the expected number of IP parts for Class A/B address validation.
ipPartsClassAB = 3

// decimalParseBase is the base for decimal digit parsing.
decimalParseBase = 10

// maxIPOctetValue is the maximum valid value for an IP address octet (255).
maxIPOctetValue = 255

// classBMinOctet is the minimum second octet for 172.x.x.x private range.
classBMinOctet = 16

// classBMaxOctet is the maximum second octet for 172.x.x.x private range.
classBMaxOctet = 31
)

// IsLocalhostOrigin validates that the origin is actually localhost.
// Prevents CORS bypass via origins like "localhost.evil.com".
func IsLocalhostOrigin(origin string) bool {
u, err := url.Parse(origin)
if err != nil {
return false
}
host := u.Hostname()
return host == "localhost" || host == "127.0.0.1" || host == "::1"
}

// IsSameOrigin checks if the Origin header matches the request's Host.
// This allows browsers to access the server from its actual address (e.g., 10.0.0.210:8444).
func IsSameOrigin(origin string, requestHost string) bool {
u, err := url.Parse(origin)
if err != nil {
return false
}
// Compare origin host:port with request host.
originHost := u.Host // Includes port if present.
return originHost == requestHost
}

// IsRFC1918Origin checks if the origin is an RFC 1918 private network address.
// Ported from Seed for CORS validation - allows connections from private networks.
//
// Allowed addresses:
// - Class A private: 10.0.0.0/8 (10.x.x.x)
// - Class B private: 172.16.0.0/12 (172.16.x.x through 172.31.x.x)
// - Class C private: 192.168.0.0/16 (192.168.x.x)
//
// Uses proper IP validation to prevent subdomain bypass attacks.
// Rejects malicious origins like "http://192.168.1.1.evil.com".
func IsRFC1918Origin(origin string) bool {
// Reject null origin
if origin == "null" {
return false
}

u, err := url.Parse(origin)
if err != nil {
return false
}

host := u.Hostname()
if host == "" {
return false
}

// Check for RFC 1918 private network ranges
return isPrivateNetworkAddress(host)
}

// isPrivateNetworkAddress checks if the host is an RFC 1918 private network address.
// This prevents subdomain attacks like "192.168.1.1.evil.com" by validating
// the complete IP address structure.
func isPrivateNetworkAddress(host string) bool {
// Class C: 192.168.0.0/16
if strings.HasPrefix(host, "192.168.") {
return isValidClassCAddress(host)
}

// Class A: 10.0.0.0/8
if strings.HasPrefix(host, "10.") {
return isValidClassAAddress(host)
}

// Class B: 172.16.0.0/12 (172.16.0.0 - 172.31.255.255)
if strings.HasPrefix(host, "172.") {
return isValidClassBAddress(host)
}

return false
}

// isValidClassCAddress validates a 192.168.x.x address.
// Returns true if the host is a valid Class C private address.
func isValidClassCAddress(host string) bool {
remainder := host[8:] // After "192.168."
// Should be X.Y where X and Y are 0-255
parts := strings.Split(remainder, ".")
if len(parts) != ipPartsClassC {
return false
}
return isValidIPOctet(parts[0]) && isValidIPOctet(parts[1])
}

// isValidClassAAddress validates a 10.x.x.x address.
// Returns true if the host is a valid Class A private address.
func isValidClassAAddress(host string) bool {
remainder := host[3:] // After "10."
parts := strings.Split(remainder, ".")
if len(parts) != ipPartsClassAB {
return false
}
return isValidIPOctet(parts[0]) && isValidIPOctet(parts[1]) && isValidIPOctet(parts[2])
}

// isValidClassBAddress validates a 172.16-31.x.x address.
// Returns true if the host is a valid Class B private address (172.16.0.0/12).
func isValidClassBAddress(host string) bool {
remainder := host[4:] // After "172."
parts := strings.Split(remainder, ".")
if len(parts) != ipPartsClassAB {
return false
}

// Validate and parse second octet to verify range 16-31
secondOctet, ok := parseOctetInRange(parts[0], classBMinOctet, classBMaxOctet)
if !ok || secondOctet < classBMinOctet || secondOctet > classBMaxOctet {
return false
}

return isValidIPOctet(parts[1]) && isValidIPOctet(parts[2])
}

// parseOctetInRange parses an octet string and checks if it's within the given range.
// Returns the parsed value and true if valid, 0 and false otherwise.
func parseOctetInRange(s string, minVal, maxVal int) (int, bool) {
if s == "" || len(s) > 3 {
return 0, false
}

val := 0
for _, c := range s {
if c < '0' || c > '9' {
return 0, false
}
val = val*decimalParseBase + int(c-'0')
if val > maxIPOctetValue {
return 0, false
}
}

if val < minVal || val > maxVal {
return val, false
}

return val, true
}

// isValidIPOctet checks if a string is a valid IP octet (0-255).
// Helper function for proper IP validation.
func isValidIPOctet(s string) bool {
if s == "" || len(s) > 3 {
return false
}

val := 0
for _, c := range s {
if c < '0' || c > '9' {
return false
}
val = val*decimalParseBase + int(c-'0')
if val > maxIPOctetValue {
return false
}
}

return true
}
113 changes: 113 additions & 0 deletions internal/api/cors/cors_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// SPDX-License-Identifier: BUSL-1.1

package cors

import "testing"

// These tests live in the internal test package so they can exercise the
// unexported IP-validation helpers directly.

// TestIsPrivateNetworkAddress tests the private network address validation helper.
func TestIsPrivateNetworkAddress(t *testing.T) {
tests := []struct {
name string
host string
want bool
}{
// Class C.
{"class C valid", "192.168.1.1", true},
{"class C zero", "192.168.0.0", true},
{"class C max", "192.168.255.255", true},

// Class A.
{"class A valid", "10.0.0.1", true},
{"class A zero", "10.0.0.0", true},
{"class A max", "10.255.255.255", true},

// Class B.
{"class B 172.16", "172.16.0.1", true},
{"class B 172.31", "172.31.255.255", true},
{"class B 172.20", "172.20.100.50", true},

// Invalid.
{"class B 172.15 invalid", "172.15.0.1", false},
{"class B 172.32 invalid", "172.32.0.1", false},
{"public IP", "8.8.8.8", false},
{"localhost", "127.0.0.1", false},
{"localhost name", "localhost", false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isPrivateNetworkAddress(tt.host)
if got != tt.want {
t.Errorf("isPrivateNetworkAddress(%q) = %v, want %v", tt.host, got, tt.want)
}
})
}
}

// TestIsValidIPOctet tests the IP octet validation helper.
func TestIsValidIPOctet(t *testing.T) {
tests := []struct {
name string
octet string
want bool
}{
{"zero", "0", true},
{"single digit", "5", true},
{"double digit", "42", true},
{"triple digit", "255", true},
{"max valid", "255", true},
{"min valid", "0", true},

// Invalid.
{"too large", "256", false},
{"way too large", "999", false},
{"empty", "", false},
{"negative", "-1", false},
{"letters", "abc", false},
{"mixed", "12a", false},
{"too long", "1234", false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isValidIPOctet(tt.octet)
if got != tt.want {
t.Errorf("isValidIPOctet(%q) = %v, want %v", tt.octet, got, tt.want)
}
})
}
}

// TestParseOctetInRange tests the octet parsing with range validation.
func TestParseOctetInRange(t *testing.T) {
tests := []struct {
name string
s string
minVal int
maxVal int
want int
wantOk bool
}{
{"in range", "20", 16, 31, 20, true},
{"at min", "16", 16, 31, 16, true},
{"at max", "31", 16, 31, 31, true},
{"below min", "15", 16, 31, 15, false},
{"above max", "32", 16, 31, 32, false},
{"empty string", "", 0, 255, 0, false},
{"invalid chars", "abc", 0, 255, 0, false},
{"too long", "1234", 0, 255, 0, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := parseOctetInRange(tt.s, tt.minVal, tt.maxVal)
if got != tt.want || ok != tt.wantOk {
t.Errorf("parseOctetInRange(%q, %d, %d) = (%d, %v), want (%d, %v)",
tt.s, tt.minVal, tt.maxVal, got, ok, tt.want, tt.wantOk)
}
})
}
}
Loading
Loading