diff --git a/bom.go b/bom.go index 687aaad..7dd0fa6 100644 --- a/bom.go +++ b/bom.go @@ -54,13 +54,19 @@ func (bs BOMService) ExportComponent(ctx context.Context, componentUUID uuid.UUI params["format"] = string(format) } - req, err := bs.client.newRequest(ctx, http.MethodGet, fmt.Sprintf("api/v1/bom/cyclonedx/component/%s", componentUUID), withParams(params)) + var acceptContentType string + switch format { + case BOMFormatJSON: + acceptContentType = "application/vnd.cyclonedx+json" + case BOMFormatXML: + acceptContentType = "application/vnd.cyclonedx+xml" + } + + req, err := bs.client.newRequest(ctx, http.MethodGet, fmt.Sprintf("api/v1/bom/cyclonedx/component/%s", componentUUID), withParams(params), withAcceptContentType(acceptContentType)) if err != nil { return } - req.Header.Set("Accept", "application/vnd.cyclonedx+json") - _, err = bs.client.doRequest(req, &bom) return } @@ -74,13 +80,19 @@ func (bs BOMService) ExportProject(ctx context.Context, projectUUID uuid.UUID, f params["variant"] = string(variant) } - req, err := bs.client.newRequest(ctx, http.MethodGet, fmt.Sprintf("api/v1/bom/cyclonedx/project/%s", projectUUID), withParams(params)) + var acceptContentType string + switch format { + case BOMFormatJSON: + acceptContentType = "application/vnd.cyclonedx+json" + case BOMFormatXML: + acceptContentType = "application/vnd.cyclonedx+xml" + } + + req, err := bs.client.newRequest(ctx, http.MethodGet, fmt.Sprintf("api/v1/bom/cyclonedx/project/%s", projectUUID), withParams(params), withAcceptContentType(acceptContentType)) if err != nil { return } - req.Header.Set("Accept", "application/vnd.cyclonedx+json") - _, err = bs.client.doRequest(req, &bom) return } diff --git a/client.go b/client.go index 7e6a346..3a12121 100644 --- a/client.go +++ b/client.go @@ -350,8 +350,15 @@ func (c Client) doRequest(req *http.Request, v interface{}) (a apiResponse, err } if v != nil { + contentType := res.Header.Get("Content-Type") + contentType = strings.SplitN(contentType, ";", 2)[0] switch vt := v.(type) { case *string: + expectedContentTypes := []string{"text/plain", "application/vnd.cyclonedx+json", "application/vnd.cyclonedx+xml"} + if !sliceContains(expectedContentTypes, contentType) { + err = fmt.Errorf("expected %s content-type, but received %s", strings.Join(expectedContentTypes, ", "), contentType) + return + } if content, readErr := io.ReadAll(res.Body); readErr == nil { *vt = strings.TrimSpace(string(content)) } else { @@ -359,6 +366,10 @@ func (c Client) doRequest(req *http.Request, v interface{}) (a apiResponse, err return } default: + if contentType != "application/json" { + err = fmt.Errorf("expected application/json content-type, but received %s", contentType) + return + } err = json.NewDecoder(res.Body).Decode(v) if err != nil { return diff --git a/user.go b/user.go index d9ee07b..142f2a4 100644 --- a/user.go +++ b/user.go @@ -50,13 +50,11 @@ func (us UserService) Login(ctx context.Context, username, password string) (tok body.Set("username", username) body.Set("password", password) - req, err := us.client.newRequest(ctx, http.MethodPost, "api/v1/user/login", withBody(body)) + req, err := us.client.newRequest(ctx, http.MethodPost, "api/v1/user/login", withBody(body), withAcceptContentType("text/plain")) if err != nil { return } - req.Header.Set("Accept", "*/*") - _, err = us.client.doRequest(req, &token) return } @@ -73,13 +71,11 @@ func (us UserService) ForceChangePassword(ctx context.Context, username, passwor body.Set("newPassword", newPassword) body.Set("confirmPassword", newPassword) - req, err := us.client.newRequest(ctx, http.MethodPost, "api/v1/user/forceChangePassword", withBody(body)) + req, err := us.client.newRequest(ctx, http.MethodPost, "api/v1/user/forceChangePassword", withBody(body), withAcceptContentType("text/plain")) if err != nil { return } - req.Header.Set("Accept", "*/*") - _, err = us.client.doRequest(req, nil) return } diff --git a/util.go b/util.go index 50afb77..6576038 100644 --- a/util.go +++ b/util.go @@ -56,3 +56,13 @@ func OptionalBoolOf(value bool) *bool { func OptionalBool() *bool { return nil } + +func sliceContains[S ~[]E, E comparable](haystack S, needle E) bool { + for _, v := range haystack { + if v == needle { + return true + } + } + return false + +} diff --git a/vex.go b/vex.go index fa572e0..fe97ad3 100644 --- a/vex.go +++ b/vex.go @@ -26,13 +26,11 @@ type vexUploadResponse struct { type VEXUploadToken string func (vs VEXService) ExportCycloneDX(ctx context.Context, projectUUID uuid.UUID) (vex string, err error) { - req, err := vs.client.newRequest(ctx, http.MethodGet, fmt.Sprintf("api/v1/vex/cyclonedx/project/%s", projectUUID)) + req, err := vs.client.newRequest(ctx, http.MethodGet, fmt.Sprintf("api/v1/vex/cyclonedx/project/%s", projectUUID), withAcceptContentType("application/vnd.cyclonedx+json")) if err != nil { return } - req.Header.Set("Accept", "application/vnd.cyclonedx+json") - _, err = vs.client.doRequest(req, &vex) return }