From d9a3c5956da184e3287f43818d8a86d40c8fc443 Mon Sep 17 00:00:00 2001 From: Thiago Pontes Date: Wed, 22 Apr 2026 14:58:48 -0400 Subject: [PATCH] distribution: Add option to pass complete *tls.Config to server --- distribution/server.go | 48 +++++++++++++++-- distribution/server_test.go | 104 ++++++++++++++++++++++++++++++++++-- 2 files changed, 144 insertions(+), 8 deletions(-) diff --git a/distribution/server.go b/distribution/server.go index cd2b7d4..88c19da 100644 --- a/distribution/server.go +++ b/distribution/server.go @@ -2,7 +2,9 @@ package distribution import ( "context" + "crypto/sha256" "crypto/tls" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -126,6 +128,7 @@ type ServerConfig struct { Addr string WebDir string Cert *certs.CertInfo + TLSConfig *tls.Config StreamLister StreamLister IngestLookup IngestLookup SRTPull SRTPullFunc @@ -180,8 +183,8 @@ type Server struct { // NewServer creates a distribution Server with the given configuration. // It returns an error if required fields are missing. func NewServer(config ServerConfig) (*Server, error) { - if config.Cert == nil { - return nil, errors.New("distribution: Cert is required") + if config.Cert == nil && config.TLSConfig == nil { + return nil, errors.New("distribution: either Cert or TLSConfig is required") } if config.Addr == "" { return nil, errors.New("distribution: Addr is required") @@ -327,8 +330,13 @@ func (s *Server) Start(ctx context.Context) error { wtMux.HandleFunc("/moq", s.handleMoQ) s.registerAPIRoutes(wtMux) - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{s.config.Cert.TLSCert}, + var tlsConfig *tls.Config + if s.config.TLSConfig != nil { + tlsConfig = s.config.TLSConfig + } else { + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{s.config.Cert.TLSCert}, + } } s.wtSrv = &webtransport.Server{ @@ -518,12 +526,42 @@ func (s *Server) handleStreamDebug(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleCertHash(w http.ResponseWriter, _ *http.Request) { + var hash string + if s.config.Cert != nil { + hash = s.config.Cert.FingerprintBase64() + } else { + cert, err := s.tlsCertificate() + if err != nil { + slog.Error("failed to get TLS certificate", "error", err) + writeError(w, http.StatusInternalServerError, "no TLS certificate available") + return + } + fp := sha256.Sum256(cert.Certificate[0]) + hash = base64.StdEncoding.EncodeToString(fp[:]) + } writeJSON(w, http.StatusOK, certHashResponse{ - Hash: s.config.Cert.FingerprintBase64(), + Hash: hash, Addr: s.config.Addr, }) } +func (s *Server) tlsCertificate() (*tls.Certificate, error) { + tc := s.config.TLSConfig + if tc.GetCertificate != nil { + cert, err := tc.GetCertificate(&tls.ClientHelloInfo{}) + if err != nil { + return nil, err + } + if cert != nil { + return cert, nil + } + } + if len(tc.Certificates) > 0 { + return &tc.Certificates[0], nil + } + return nil, errors.New("no certificate in TLSConfig") +} + func (s *Server) handleSRTPullOptions(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type") diff --git a/distribution/server_test.go b/distribution/server_test.go index b5469d7..b859664 100644 --- a/distribution/server_test.go +++ b/distribution/server_test.go @@ -1,6 +1,9 @@ package distribution import ( + "crypto/sha256" + "crypto/tls" + "encoding/base64" "encoding/json" "net/http" "net/http/httptest" @@ -111,6 +114,87 @@ func TestHandleCertHash(t *testing.T) { } } +func TestHandleCertHashWithTLSConfig(t *testing.T) { + t.Parallel() + + cert, err := certs.Generate(24 * 60 * 60 * 1e9) + if err != nil { + t.Fatalf("certs.Generate: %v", err) + } + + tlsCfg := &tls.Config{Certificates: []tls.Certificate{cert.TLSCert}} + srv, err := NewServer(ServerConfig{ + Addr: ":0", + TLSConfig: tlsCfg, + }) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + handler := srv.APIHandler() + + req := httptest.NewRequest("GET", "/api/cert-hash", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp certHashResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + + // Verify the hash matches what we'd compute from the raw DER cert. + fp := sha256.Sum256(cert.TLSCert.Certificate[0]) + want := base64.StdEncoding.EncodeToString(fp[:]) + if resp.Hash != want { + t.Fatalf("hash = %q, want %q", resp.Hash, want) + } +} + +func TestHandleCertHashWithGetCertificate(t *testing.T) { + t.Parallel() + + cert, err := certs.Generate(24 * 60 * 60 * 1e9) + if err != nil { + t.Fatalf("certs.Generate: %v", err) + } + + tlsCfg := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return &cert.TLSCert, nil + }, + } + srv, err := NewServer(ServerConfig{ + Addr: ":0", + TLSConfig: tlsCfg, + }) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + handler := srv.APIHandler() + + req := httptest.NewRequest("GET", "/api/cert-hash", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp certHashResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + + fp := sha256.Sum256(cert.TLSCert.Certificate[0]) + want := base64.StdEncoding.EncodeToString(fp[:]) + if resp.Hash != want { + t.Fatalf("hash = %q, want %q", resp.Hash, want) + } +} + func TestHandleStreamDebugNotFound(t *testing.T) { t.Parallel() @@ -265,11 +349,11 @@ func TestNewServerValidation(t *testing.T) { t.Fatalf("certs.Generate: %v", err) } - t.Run("missing cert", func(t *testing.T) { + t.Run("missing cert and tls config", func(t *testing.T) { t.Parallel() _, err := NewServer(ServerConfig{Addr: ":4443"}) if err == nil { - t.Fatal("expected error for missing cert") + t.Fatal("expected error for missing cert and tls config") } }) @@ -281,7 +365,7 @@ func TestNewServerValidation(t *testing.T) { } }) - t.Run("valid config", func(t *testing.T) { + t.Run("valid config with cert", func(t *testing.T) { t.Parallel() srv, err := NewServer(ServerConfig{Addr: ":4443", Cert: cert}) if err != nil { @@ -291,6 +375,20 @@ func TestNewServerValidation(t *testing.T) { t.Fatal("server is nil") } }) + + t.Run("valid config with tls config", func(t *testing.T) { + t.Parallel() + srv, err := NewServer(ServerConfig{ + Addr: ":4443", + TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert.TLSCert}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if srv == nil { + t.Fatal("server is nil") + } + }) } func TestStreamLifecycleCallbacks(t *testing.T) {