Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions pkg/network/tracer/testutil/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func (t *TCPServer) Addr() net.Addr {
return t.ln.Addr()
}

// Run starts the TCP server
func (t *TCPServer) Run() error {
// Listen sets up the socket with net.Listen
func (t *TCPServer) Listen() error {
networkType := "tcp"
if t.Network != "" {
networkType = t.Network
Expand All @@ -59,11 +59,16 @@ func (t *TCPServer) Run() error {
return err
}
t.ln = ln

t.address = ln.Addr().String()
return nil
}

// StartAccepting starts up the server's Accept goroutine
func (t *TCPServer) StartAccepting() {
go func() {
for {
conn, err := ln.Accept()
conn, err := t.ln.Accept()
if err != nil {
return
}
Expand All @@ -74,7 +79,16 @@ func (t *TCPServer) Run() error {
go t.onMessage(conn)
}
}()
}

// Run starts the TCP server
func (t *TCPServer) Run() error {
err := t.Listen()
if err != nil {
return err
}

t.StartAccepting()
return nil
}

Expand All @@ -101,7 +115,7 @@ func DialTCP(network, address string) (net.Conn, error) {
func (t *TCPServer) Shutdown() {
if t.ln != nil {
_ = t.ln.Close()
t.ln = nil
// do not set t.ln = nil here, because otherwise t.ln.Accept() can panic later
}
}

Expand Down
152 changes: 12 additions & 140 deletions pkg/networkpath/traceroute/common/traceroute_parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ package common

import (
"context"
"errors"
"fmt"
"net/netip"
"slices"
"sync"
"time"

Expand All @@ -19,78 +16,9 @@ import (
"github.com/DataDog/datadog-agent/pkg/util/log"
)

// ReceiveProbeNoPktError is returned when ReceiveProbe() didn't find anything new.
// This is normal if the RTT is long
type ReceiveProbeNoPktError struct {
Err error
}

func (e *ReceiveProbeNoPktError) Error() string {
return fmt.Sprintf("ReceiveProbe() didn't find any new packets: %s", e.Err)
}
func (e *ReceiveProbeNoPktError) Unwrap() error {
return e.Err
}

// BadPacketError is a non-fatal error that occurs when a packet is malformed.
type BadPacketError struct {
Err error
}

func (e *BadPacketError) Error() string {
return fmt.Sprintf("Failed to parse packet: %s", e.Err)
}
func (e *BadPacketError) Unwrap() error {
return e.Err
}

// ProbeResponse is the response of a single probe in a traceroute
type ProbeResponse struct {
// TTL is the Time To Live of the probe that was originally sent
TTL uint8
// IP is the IP address of the responding host
IP netip.Addr
// RTT is the round-trip time of the probe
RTT time.Duration
// IsDest is true if the responding host is the destination
IsDest bool
}

// TracerouteDriverInfo is metadata about a TracerouteDriver
type TracerouteDriverInfo struct {
SupportsParallel bool
}

// TracerouteDriver is an implementation of traceroute send+receive of packets
type TracerouteDriver interface {
// GetDriverInfo returns metadata about this driver
GetDriverInfo() TracerouteDriverInfo
// SendProbe sends a traceroute packet with a specific TTL
SendProbe(ttl uint8) error
// ReceiveProbe polls to get a traceroute response with a timeout.
ReceiveProbe(timeout time.Duration) (*ProbeResponse, error)
}

// TracerouteParallelParams are the parameters for TracerouteParallel
type TracerouteParallelParams struct {
// MinTTL is the TTL to start the traceroute at
MinTTL uint8
// MaxTTL is the TTL to end the traceroute at
MaxTTL uint8
// TracerouteTimeout is the maximum time to wait for a response
TracerouteTimeout time.Duration
// PollFrequency is how often to poll for a response
PollFrequency time.Duration
// SendDelay is the delay between sending probes (typically small)
SendDelay time.Duration
}

// ProbeCount returns the number of probes that will be sent
func (p TracerouteParallelParams) ProbeCount() int {
if p.MinTTL > p.MaxTTL {
return 0
}
return int(p.MaxTTL) - int(p.MinTTL) + 1
TracerouteParams
}

// MaxTimeout combines the timeout+probe delays into a total timeout for the traceroute
Expand All @@ -101,12 +29,10 @@ func (p TracerouteParallelParams) MaxTimeout() time.Duration {

// TracerouteParallel runs a traceroute in parallel
func TracerouteParallel(ctx context.Context, t TracerouteDriver, p TracerouteParallelParams) ([]*ProbeResponse, error) {
if p.MinTTL > p.MaxTTL {
return nil, fmt.Errorf("min TTL must be less than or equal to max TTL")
}
if p.MinTTL < 1 {
return nil, fmt.Errorf("min TTL must be at least 1")
if err := p.validate(); err != nil {
return nil, err
}

info := t.GetDriverInfo()
if !info.SupportsParallel {
return nil, fmt.Errorf("tried to call TracerouteParallel on a TracerouteDriver that doesn't support parallel")
Expand Down Expand Up @@ -144,15 +70,13 @@ func TracerouteParallel(ctx context.Context, t TracerouteDriver, p TraceroutePar
g.Go(func() error {
for i := int(p.MinTTL); i <= int(p.MaxTTL); i++ {
// leave if we got cancelled
select {
case <-writerCtx.Done():
if writerCtx.Err() != nil {
return nil
default:
}

err := t.SendProbe(uint8(i))
if err != nil {
return err
return fmt.Errorf("SendProbe() failed: %w", err)
}

time.Sleep(p.SendDelay)
Expand All @@ -163,28 +87,19 @@ func TracerouteParallel(ctx context.Context, t TracerouteDriver, p TraceroutePar
g.Go(func() error {
for {
// leave if we got cancelled, SendProbe() failed, etc
select {
// doesn't use writerCtx because even if we writerCancel(), we want to keep reading
case <-groupCtx.Done():
// doesn't use writerCtx because when we find the destination, we writerCancel(), and we want to keep reading
if groupCtx.Err() != nil {
return nil
default:
}

probe, err := t.ReceiveProbe(p.PollFrequency)
if CheckParallelRetryable("ReceiveProbe", err) {
if CheckProbeRetryable("ReceiveProbe", err) {
continue
} else if err != nil {
return fmt.Errorf("ReceiveProbe() failed: %w", err)
} else if err = p.validateProbe(probe); err != nil {
return err
}
if probe == nil {
return fmt.Errorf("ReceiveProbe() returned nil without an error (this indicates a bug in the TracerouteDriver)")
}
if probe.TTL == 0 {
return fmt.Errorf("ReceiveProbe() got TTL 0 which is only allowed for TracerouteSerial (this indicates a bug in the TracerouteDriver)")
}
if probe.TTL < p.MinTTL || probe.TTL > p.MaxTTL {
return fmt.Errorf("ReceiveProbe() received an invalid TTL: expected TTL in [%d, %d], got %d", p.MinTTL, p.MaxTTL, probe.TTL)
}

writeProbe(probe)
// no need to send more probes if we found the destination
Expand All @@ -205,48 +120,5 @@ func TracerouteParallel(ctx context.Context, t TracerouteDriver, p TraceroutePar
return nil, ctx.Err()
}

destIdx := slices.IndexFunc(results, func(pr *ProbeResponse) bool {
return pr != nil && pr.IsDest
})
// trim off anything after the destination
if destIdx != -1 {
results = slices.Clip(results[:destIdx+1])
}

return results[p.MinTTL:], nil
}

// ToHops converts a list of ProbeResponses to a Results
// TODO remove this, and use a single type to represent results
func ToHops(p TracerouteParallelParams, probes []*ProbeResponse) ([]*Hop, error) {
if p.MinTTL != 1 {
return nil, fmt.Errorf("ToHops: processResults() requires MinTTL == 1")
}
hops := make([]*Hop, len(probes))
for i, probe := range probes {
hops[i] = &Hop{}
if probe != nil {
hops[i].IP = probe.IP.AsSlice()
hops[i].RTT = probe.RTT
hops[i].IsDest = probe.IsDest
}
}
return hops, nil
}

var badPktLimit = log.NewLogLimit(10, 5*time.Minute)

// CheckParallelRetryable returns whether ReceiveProbe failed due to a real error or just an irrelevant packet
func CheckParallelRetryable(funcName string, err error) bool {
noPktErr := &ReceiveProbeNoPktError{}
badPktErr := &BadPacketError{}
if errors.As(err, &noPktErr) {
return true
} else if errors.As(err, &badPktErr) {
if badPktLimit.ShouldLog() {
log.Warnf("%s() saw a malformed packet: %s", funcName, err)
}
return true
}
return false
return clipResults(p.MinTTL, results), nil
}
Loading