diff --git a/contrib/net/http/internal/wrap/roundtrip.go b/contrib/net/http/internal/wrap/roundtrip.go index d1d25dc457..4842371d71 100644 --- a/contrib/net/http/internal/wrap/roundtrip.go +++ b/contrib/net/http/internal/wrap/roundtrip.go @@ -13,6 +13,7 @@ import ( "net/http/httptrace" "os" "strconv" + "sync" "time" "github.com/DataDog/dd-trace-go/contrib/net/http/v2/internal/config" @@ -28,15 +29,16 @@ import ( type AfterRoundTrip = func(*http.Response, error) (*http.Response, error) -// httpTraceTimings captures key timing events from httptrace.ClientTrace +// httpTraceTimings captures key timing events from httptrace.ClientTrace. type httpTraceTimings struct { - dnsStart, dnsEnd time.Time - connectStart, connectEnd time.Time - tlsStart, tlsEnd time.Time - getConnStart, gotConn time.Time - wroteHeaders, gotFirstByte time.Time - connectErr error - tlsErr error + mu sync.Mutex + dnsStart, dnsEnd time.Time // +checklocks:mu + connectStart, connectEnd time.Time // +checklocks:mu + tlsStart, tlsEnd time.Time // +checklocks:mu + getConnStart, gotConn time.Time // +checklocks:mu + wroteHeaders, gotFirstByte time.Time // +checklocks:mu + connectErr error // +checklocks:mu + tlsErr error // +checklocks:mu } // addDurationTag adds a timing tag to the span if both timestamps are valid @@ -49,6 +51,9 @@ func (t *httpTraceTimings) addDurationTag(span *tracer.Span, tagName string, sta // addTimingTags adds all timing information to the span func (t *httpTraceTimings) addTimingTags(span *tracer.Span) { + t.mu.Lock() + defer t.mu.Unlock() + t.addDurationTag(span, "http.dns.duration_ms", t.dnsStart, t.dnsEnd) t.addDurationTag(span, "http.connect.duration_ms", t.connectStart, t.connectEnd) t.addDurationTag(span, "http.tls.duration_ms", t.tlsStart, t.tlsEnd) @@ -67,16 +72,58 @@ func (t *httpTraceTimings) addTimingTags(span *tracer.Span) { // newClientTrace creates a ClientTrace that captures timing information func newClientTrace(timings *httpTraceTimings) *httptrace.ClientTrace { return &httptrace.ClientTrace{ - DNSStart: func(httptrace.DNSStartInfo) { timings.dnsStart = time.Now() }, - DNSDone: func(httptrace.DNSDoneInfo) { timings.dnsEnd = time.Now() }, - ConnectStart: func(network, addr string) { timings.connectStart = time.Now() }, - ConnectDone: func(network, addr string, err error) { timings.connectEnd = time.Now(); timings.connectErr = err }, - TLSHandshakeStart: func() { timings.tlsStart = time.Now() }, - TLSHandshakeDone: func(_ tls.ConnectionState, err error) { timings.tlsEnd = time.Now(); timings.tlsErr = err }, - GetConn: func(hostPort string) { timings.getConnStart = time.Now() }, - GotConn: func(httptrace.GotConnInfo) { timings.gotConn = time.Now() }, - WroteHeaders: func() { timings.wroteHeaders = time.Now() }, - GotFirstResponseByte: func() { timings.gotFirstByte = time.Now() }, + DNSStart: func(httptrace.DNSStartInfo) { + timings.mu.Lock() + timings.dnsStart = time.Now() + timings.mu.Unlock() + }, + DNSDone: func(httptrace.DNSDoneInfo) { + timings.mu.Lock() + timings.dnsEnd = time.Now() + timings.mu.Unlock() + }, + ConnectStart: func(network, addr string) { + timings.mu.Lock() + timings.connectStart = time.Now() + timings.mu.Unlock() + }, + ConnectDone: func(network, addr string, err error) { + timings.mu.Lock() + timings.connectEnd = time.Now() + timings.connectErr = err + timings.mu.Unlock() + }, + TLSHandshakeStart: func() { + timings.mu.Lock() + timings.tlsStart = time.Now() + timings.mu.Unlock() + }, + TLSHandshakeDone: func(_ tls.ConnectionState, err error) { + timings.mu.Lock() + timings.tlsEnd = time.Now() + timings.tlsErr = err + timings.mu.Unlock() + }, + GetConn: func(hostPort string) { + timings.mu.Lock() + timings.getConnStart = time.Now() + timings.mu.Unlock() + }, + GotConn: func(httptrace.GotConnInfo) { + timings.mu.Lock() + timings.gotConn = time.Now() + timings.mu.Unlock() + }, + WroteHeaders: func() { + timings.mu.Lock() + timings.wroteHeaders = time.Now() + timings.mu.Unlock() + }, + GotFirstResponseByte: func() { + timings.mu.Lock() + timings.gotFirstByte = time.Now() + timings.mu.Unlock() + }, } } diff --git a/contrib/net/http/roundtripper_test.go b/contrib/net/http/roundtripper_test.go index 584f64c5e5..99a791d3d7 100644 --- a/contrib/net/http/roundtripper_test.go +++ b/contrib/net/http/roundtripper_test.go @@ -15,6 +15,7 @@ import ( "regexp" "strconv" "strings" + "sync" "testing" "time" @@ -637,6 +638,39 @@ func TestClientTimings(t *testing.T) { }) } +func TestClientTimingsRace(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + rt := WrapRoundTripper(http.DefaultTransport, WithClientTimings(true)) + client := &http.Client{Transport: rt} + + const numGoroutines = 10 + const numReqs = 10 + + var wg sync.WaitGroup + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < numReqs; j++ { + resp, err := client.Get(srv.URL) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + resp.Body.Close() + } + }() + } + wg.Wait() +} + func TestClientQueryStringCollected(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Write([]byte("Hello World"))