Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions distribution/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package distribution

import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -126,6 +128,7 @@ type ServerConfig struct {
Addr string
WebDir string
Cert *certs.CertInfo
TLSConfig *tls.Config
StreamLister StreamLister
IngestLookup IngestLookup
SRTPull SRTPullFunc
Expand Down Expand Up @@ -180,8 +183,8 @@ type Server struct {
// NewServer creates a distribution Server with the given configuration.
// It returns an error if required fields are missing.
func NewServer(config ServerConfig) (*Server, error) {
if config.Cert == nil {
return nil, errors.New("distribution: Cert is required")
if config.Cert == nil && config.TLSConfig == nil {
return nil, errors.New("distribution: either Cert or TLSConfig is required")
}
if config.Addr == "" {
return nil, errors.New("distribution: Addr is required")
Expand Down Expand Up @@ -327,8 +330,13 @@ func (s *Server) Start(ctx context.Context) error {
wtMux.HandleFunc("/moq", s.handleMoQ)
s.registerAPIRoutes(wtMux)

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{s.config.Cert.TLSCert},
var tlsConfig *tls.Config
if s.config.TLSConfig != nil {
tlsConfig = s.config.TLSConfig
} else {
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{s.config.Cert.TLSCert},
}
}

s.wtSrv = &webtransport.Server{
Expand Down Expand Up @@ -518,12 +526,42 @@ func (s *Server) handleStreamDebug(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleCertHash(w http.ResponseWriter, _ *http.Request) {
var hash string
if s.config.Cert != nil {
hash = s.config.Cert.FingerprintBase64()
} else {
cert, err := s.tlsCertificate()
if err != nil {
slog.Error("failed to get TLS certificate", "error", err)
writeError(w, http.StatusInternalServerError, "no TLS certificate available")
return
}
fp := sha256.Sum256(cert.Certificate[0])
hash = base64.StdEncoding.EncodeToString(fp[:])
}
writeJSON(w, http.StatusOK, certHashResponse{
Hash: s.config.Cert.FingerprintBase64(),
Hash: hash,
Addr: s.config.Addr,
})
}

func (s *Server) tlsCertificate() (*tls.Certificate, error) {
tc := s.config.TLSConfig
if tc.GetCertificate != nil {
cert, err := tc.GetCertificate(&tls.ClientHelloInfo{})
if err != nil {
return nil, err
}
if cert != nil {
return cert, nil
}
}
if len(tc.Certificates) > 0 {
return &tc.Certificates[0], nil
}
return nil, errors.New("no certificate in TLSConfig")
}

func (s *Server) handleSRTPullOptions(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
Expand Down
104 changes: 101 additions & 3 deletions distribution/server_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package distribution

import (
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -111,6 +114,87 @@ func TestHandleCertHash(t *testing.T) {
}
}

func TestHandleCertHashWithTLSConfig(t *testing.T) {
t.Parallel()

cert, err := certs.Generate(24 * 60 * 60 * 1e9)
if err != nil {
t.Fatalf("certs.Generate: %v", err)
}

tlsCfg := &tls.Config{Certificates: []tls.Certificate{cert.TLSCert}}
srv, err := NewServer(ServerConfig{
Addr: ":0",
TLSConfig: tlsCfg,
})
if err != nil {
t.Fatalf("NewServer: %v", err)
}
handler := srv.APIHandler()

req := httptest.NewRequest("GET", "/api/cert-hash", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}

var resp certHashResponse
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("decode: %v", err)
}

// Verify the hash matches what we'd compute from the raw DER cert.
fp := sha256.Sum256(cert.TLSCert.Certificate[0])
want := base64.StdEncoding.EncodeToString(fp[:])
if resp.Hash != want {
t.Fatalf("hash = %q, want %q", resp.Hash, want)
}
}

func TestHandleCertHashWithGetCertificate(t *testing.T) {
t.Parallel()

cert, err := certs.Generate(24 * 60 * 60 * 1e9)
if err != nil {
t.Fatalf("certs.Generate: %v", err)
}

tlsCfg := &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return &cert.TLSCert, nil
},
}
srv, err := NewServer(ServerConfig{
Addr: ":0",
TLSConfig: tlsCfg,
})
if err != nil {
t.Fatalf("NewServer: %v", err)
}
handler := srv.APIHandler()

req := httptest.NewRequest("GET", "/api/cert-hash", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}

var resp certHashResponse
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("decode: %v", err)
}

fp := sha256.Sum256(cert.TLSCert.Certificate[0])
want := base64.StdEncoding.EncodeToString(fp[:])
if resp.Hash != want {
t.Fatalf("hash = %q, want %q", resp.Hash, want)
}
}

func TestHandleStreamDebugNotFound(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -265,11 +349,11 @@ func TestNewServerValidation(t *testing.T) {
t.Fatalf("certs.Generate: %v", err)
}

t.Run("missing cert", func(t *testing.T) {
t.Run("missing cert and tls config", func(t *testing.T) {
t.Parallel()
_, err := NewServer(ServerConfig{Addr: ":4443"})
if err == nil {
t.Fatal("expected error for missing cert")
t.Fatal("expected error for missing cert and tls config")
}
})

Expand All @@ -281,7 +365,7 @@ func TestNewServerValidation(t *testing.T) {
}
})

t.Run("valid config", func(t *testing.T) {
t.Run("valid config with cert", func(t *testing.T) {
t.Parallel()
srv, err := NewServer(ServerConfig{Addr: ":4443", Cert: cert})
if err != nil {
Expand All @@ -291,6 +375,20 @@ func TestNewServerValidation(t *testing.T) {
t.Fatal("server is nil")
}
})

t.Run("valid config with tls config", func(t *testing.T) {
t.Parallel()
srv, err := NewServer(ServerConfig{
Addr: ":4443",
TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert.TLSCert}},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if srv == nil {
t.Fatal("server is nil")
}
})
}

func TestStreamLifecycleCallbacks(t *testing.T) {
Expand Down