diff --git a/internal/api/handlers/v0/list_errors.go b/internal/api/handlers/v0/list_errors.go index b83f04e5a..12cbd9bfc 100644 --- a/internal/api/handlers/v0/list_errors.go +++ b/internal/api/handlers/v0/list_errors.go @@ -8,13 +8,20 @@ import ( "github.com/danielgtaylor/huma/v2" ) +func clientClosedRequest(ctx context.Context, err error) (error, bool) { + if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) { + return huma.NewError(499, "Client closed request", err), true + } + return nil, false +} + // ListServersError maps ListServers failures; client disconnects must not log as 500s. func ListServersError(ctx context.Context, err error) error { if err == nil { return nil } - if errors.Is(err, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) { - return huma.NewError(499, "Client closed request", err) + if cerr, ok := clientClosedRequest(ctx, err); ok { + return cerr } log.Printf("list servers failed: %v", err) // Do not pass err here: huma serializes extra error args into the response @@ -22,3 +29,27 @@ func ListServersError(ctx context.Context, err error) error { // server-side only, like the sibling handlers in servers.go. return huma.Error500InternalServerError("Failed to get registry list") } + +// GetServerDetailsError maps get-server-version failures; client disconnects must not log as 500s. +func GetServerDetailsError(ctx context.Context, err error, serverName, version string) error { + if err == nil { + return nil + } + if cerr, ok := clientClosedRequest(ctx, err); ok { + return cerr + } + log.Printf("get server details (%q/%q) failed: %v", serverName, version, err) + return huma.Error500InternalServerError("Failed to get server details") +} + +// GetServerVersionsError maps get-server-versions failures; client disconnects must not log as 500s. +func GetServerVersionsError(ctx context.Context, err error, serverName string) error { + if err == nil { + return nil + } + if cerr, ok := clientClosedRequest(ctx, err); ok { + return cerr + } + log.Printf("get server versions (%q) failed: %v", serverName, err) + return huma.Error500InternalServerError("Failed to get server versions") +} diff --git a/internal/api/handlers/v0/list_errors_test.go b/internal/api/handlers/v0/list_errors_test.go index 7c50d382a..3519134f8 100644 --- a/internal/api/handlers/v0/list_errors_test.go +++ b/internal/api/handlers/v0/list_errors_test.go @@ -43,3 +43,41 @@ func TestListServersError_realFailureDoesNotLeakDetail(t *testing.T) { assert.NotContains(t, string(body), "database unavailable") assert.NotContains(t, string(body), "internal-host") } + +func TestGetServerDetailsError_clientCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := v0.GetServerDetailsError(ctx, errors.New("query canceled"), "io.github.acme/widget", "1.0.0") + require.Error(t, err) + var se huma.StatusError + require.True(t, errors.As(err, &se)) + assert.Equal(t, 499, se.GetStatus()) +} + +func TestGetServerDetailsError_realFailure(t *testing.T) { + err := v0.GetServerDetailsError(context.Background(), errors.New("database unavailable"), "io.github.acme/widget", "1.0.0") + require.Error(t, err) + var se huma.StatusError + require.True(t, errors.As(err, &se)) + assert.Equal(t, 500, se.GetStatus()) +} + +func TestGetServerVersionsError_clientCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := v0.GetServerVersionsError(ctx, errors.New("query canceled"), "io.github.acme/widget") + require.Error(t, err) + var se huma.StatusError + require.True(t, errors.As(err, &se)) + assert.Equal(t, 499, se.GetStatus()) +} + +func TestGetServerVersionsError_realFailure(t *testing.T) { + err := v0.GetServerVersionsError(context.Background(), errors.New("database unavailable"), "io.github.acme/widget") + require.Error(t, err) + var se huma.StatusError + require.True(t, errors.As(err, &se)) + assert.Equal(t, 500, se.GetStatus()) +} diff --git a/internal/api/handlers/v0/servers.go b/internal/api/handlers/v0/servers.go index 059095d28..448c8845e 100644 --- a/internal/api/handlers/v0/servers.go +++ b/internal/api/handlers/v0/servers.go @@ -3,7 +3,6 @@ package v0 import ( "context" "errors" - "log" "net/http" "net/url" "reflect" @@ -185,8 +184,7 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service. if err.Error() == errRecordNotFound || errors.Is(err, database.ErrNotFound) { return nil, huma.Error404NotFound("Server not found") } - log.Printf("get server details (%q/%q) failed: %v", serverName, version, err) - return nil, huma.Error500InternalServerError("Failed to get server details") + return nil, GetServerDetailsError(ctx, err, serverName, version) } return &Response[apiv0.ServerResponse]{ @@ -215,8 +213,7 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service. if err.Error() == errRecordNotFound || errors.Is(err, database.ErrNotFound) { return nil, huma.Error404NotFound("Server not found") } - log.Printf("get server versions (%q) failed: %v", serverName, err) - return nil, huma.Error500InternalServerError("Failed to get server versions") + return nil, GetServerVersionsError(ctx, err, serverName) } // Convert []*ServerResponse to []ServerResponse