Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions pkg/fuzz/component/body.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,24 @@ func (b *Body) Parse(req *retryablehttp.Request) (bool, error) {

// parseBody parses a body with a custom decoder
func (b *Body) parseBody(decoderName string, req *retryablehttp.Request) (bool, error) {
decoder := dataformat.Get(decoderName)
var decoder dataformat.DataFormat
if decoderName == dataformat.MultiPartFormDataFormat {
// set content type to extract boundary
if err := decoder.(*dataformat.MultiPartForm).ParseBoundary(req.Header.Get("Content-Type")); err != nil {
// multipart has per-request state (boundary, file metadata) so we need
// a fresh instance to avoid concurrent writes to the global singleton
mpf := dataformat.NewMultiPartForm()
if err := mpf.ParseBoundary(req.Header.Get("Content-Type")); err != nil {
return false, errors.Wrap(err, "could not parse boundary")
}
decoder = mpf
} else {
decoder = dataformat.Get(decoderName)
}
decoded, err := decoder.Decode(b.value.String())
if err != nil {
return false, errors.Wrap(err, "could not decode raw")
}
b.value.SetParsed(decoded, decoder.Name())
b.value.encoder = decoder
return true, nil
}

Expand Down
71 changes: 71 additions & 0 deletions pkg/fuzz/component/body_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package component

import (
"bytes"
"fmt"
"io"
"mime/multipart"
"strings"
"sync"
"testing"

"github.com/projectdiscovery/retryablehttp-go"
Expand Down Expand Up @@ -173,3 +175,72 @@ func TestMultiPartFormComponent(t *testing.T) {
require.Contains(t, string(newBody), "username", "unexpected body content")
require.Contains(t, string(newBody), "testuser", "unexpected body content")
}

// each goroutine creates its own multipart body with a unique boundary
// and parses it concurrently. before the fix this would crash with
// "fatal error: concurrent map writes" because all goroutines shared
// the same MultiPartForm singleton.
func TestMultiPartFormConcurrentParse(t *testing.T) {
const goroutines = 20
var wg sync.WaitGroup
wg.Add(goroutines)

errs := make(chan error, goroutines)

for i := 0; i < goroutines; i++ {
go func(id int) {
defer wg.Done()

formData := &bytes.Buffer{}
writer := multipart.NewWriter(formData)
_ = writer.WriteField("field", fmt.Sprintf("value-%d", id))
contentType := writer.FormDataContentType()
_ = writer.Close()

req, err := retryablehttp.NewRequest("POST", "https://example.com", bytes.NewReader(formData.Bytes()))
if err != nil {
errs <- fmt.Errorf("goroutine %d: new request: %w", id, err)
return
}
req.Header.Set("Content-Type", contentType)

body := NewBody()
parsed, err := body.Parse(req)
if err != nil {
errs <- fmt.Errorf("goroutine %d: parse: %w", id, err)
return
}
if !parsed {
errs <- fmt.Errorf("goroutine %d: body was not parsed", id)
return
}

_ = body.SetValue("field", fmt.Sprintf("fuzzed-%d", id))

rebuilt, err := body.Rebuild()
if err != nil {
errs <- fmt.Errorf("goroutine %d: rebuild: %w", id, err)
return
}

rebuiltBody, err := io.ReadAll(rebuilt.Body)
if err != nil {
errs <- fmt.Errorf("goroutine %d: read rebuilt body: %w", id, err)
return
}

expected := fmt.Sprintf("fuzzed-%d", id)
if !strings.Contains(string(rebuiltBody), expected) {
errs <- fmt.Errorf("goroutine %d: rebuilt body missing %q", id, expected)
return
}
}(i)
}

wg.Wait()
close(errs)

for err := range errs {
t.Error(err)
}
}
13 changes: 11 additions & 2 deletions pkg/fuzz/component/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Value struct {
data string
parsed dataformat.KV
dataFormat string
encoder dataformat.DataFormat
}

// NewValue returns a new value component
Expand All @@ -42,6 +43,7 @@ func (v *Value) Clone() *Value {
data: v.data,
parsed: v.parsed.Clone(),
dataFormat: v.dataFormat,
encoder: v.encoder,
}
}

Expand Down Expand Up @@ -135,12 +137,19 @@ func (v *Value) Delete(key string) bool {

// Encode encodes the value into a string
// using the dataformat and encoding
func (v *Value) encode(data dataformat.KV) (string, error) {
if v.encoder != nil {
return v.encoder.Encode(data)
}
return dataformat.Encode(data, v.dataFormat)
}

func (v *Value) Encode() (string, error) {
toEncodeStr := v.data
if v.parsed.OrderedMap != nil {
// flattening orderedmap not supported
if v.dataFormat != "" {
dataformatStr, err := dataformat.Encode(v.parsed, v.dataFormat)
dataformatStr, err := v.encode(v.parsed)
if err != nil {
return "", err
}
Expand All @@ -154,7 +163,7 @@ func (v *Value) Encode() (string, error) {
return "", err
}
if v.dataFormat != "" {
dataformatStr, err := dataformat.Encode(dataformat.KVMap(nested), v.dataFormat)
dataformatStr, err := v.encode(dataformat.KVMap(nested))
if err != nil {
return "", err
}
Expand Down
Loading