diff --git a/internal/pdcp/writer.go b/internal/pdcp/writer.go index 3adf91b6..9fe8f219 100644 --- a/internal/pdcp/writer.go +++ b/internal/pdcp/writer.go @@ -52,6 +52,7 @@ type UploadWriter struct { assetGroupID string assetGroupName string counter atomic.Int32 + droppedCounter atomic.Int32 closed atomic.Bool TeamID string } @@ -65,7 +66,7 @@ func NewUploadWriterCallback(ctx context.Context, creds *pdcpauth.PDCPCredential u := &UploadWriter{ creds: creds, done: make(chan struct{}, 1), - data: make(chan *clients.Response, 8), // default buffer size + data: make(chan *clients.Response, 1000), // increased buffer size TeamID: "", } var err error @@ -91,7 +92,12 @@ func NewUploadWriterCallback(ctx context.Context, creds *pdcpauth.PDCPCredential // GetWriterCallback returns the writer callback func (u *UploadWriter) GetWriterCallback() func(*clients.Response) { return func(resp *clients.Response) { - u.data <- resp + select { + case u.data <- resp: + default: + u.droppedCounter.Add(1) + gologger.Warning().Msgf("PDCP upload buffer full, skipping result") + } } } @@ -125,6 +131,9 @@ func (u *UploadWriter) autoCommit(ctx context.Context) { } else { gologger.Info().Msgf("Found %v results, View found results in dashboard : %v", u.counter.Load(), getAssetsDashBoardURL(u.assetGroupID, u.TeamID)) } + if dropped := u.droppedCounter.Load(); dropped > 0 { + gologger.Warning().Msgf("Dropped %v results due to upload buffer overflow", dropped) + } }() // temporary buffer to store the results buff := &bytes.Buffer{} diff --git a/pkg/tlsx/tls/tls.go b/pkg/tlsx/tls/tls.go index c07a5ed2..c3bddd8e 100644 --- a/pkg/tlsx/tls/tls.go +++ b/pkg/tlsx/tls/tls.go @@ -236,10 +236,18 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con conn := tls.Client(baseConn, baseCfg) - if err := conn.Handshake(); err == nil { + ctx := context.Background() + var cancel context.CancelFunc + if c.options.Timeout != 0 { + ctx, cancel = context.WithTimeout(ctx, time.Duration(c.options.Timeout)*time.Second) + } + if err := conn.HandshakeContext(ctx); err == nil { ciphersuite := conn.ConnectionState().CipherSuite enumeratedCiphers = append(enumeratedCiphers, tls.CipherSuiteName(ciphersuite)) } + if cancel != nil { + cancel() + } _ = conn.Close() // close baseConn internally } return enumeratedCiphers, nil diff --git a/pkg/tlsx/ztls/ztls.go b/pkg/tlsx/ztls/ztls.go index a03b7267..be09a8e4 100644 --- a/pkg/tlsx/ztls/ztls.go +++ b/pkg/tlsx/ztls/ztls.go @@ -140,7 +140,7 @@ func (c *Client) ConnectWithOptions(hostname, ip, port string, options clients.C // new tls connection tlsConn := tls.Client(conn, config) - err = c.tlsHandshakeWithTimeout(tlsConn, ctx) + err = c.tlsHandshakeWithTimeout(ctx, tlsConn) if err != nil { if clients.IsClientCertRequiredError(err) { clientCertRequired = true @@ -257,10 +257,18 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con conn := tls.Client(baseConn, baseCfg) baseCfg.CipherSuites = []uint16{ztlsCiphers[v]} - if err := c.tlsHandshakeWithTimeout(conn, context.TODO()); err == nil { + ctx := context.Background() + var cancel context.CancelFunc + if c.options.Timeout != 0 { + ctx, cancel = context.WithTimeout(ctx, time.Duration(c.options.Timeout)*time.Second) + } + if err := c.tlsHandshakeWithTimeout(ctx, conn); err == nil { h1 := conn.GetHandshakeLog() enumeratedCiphers = append(enumeratedCiphers, h1.ServerHello.CipherSuite.String()) } + if cancel != nil { + cancel() + } _ = conn.Close() // also closes baseConn internally } return enumeratedCiphers, nil @@ -321,19 +329,20 @@ func (c *Client) getConfig(hostname, ip, port string, options clients.ConnectOpt } // tlsHandshakeWithCtx attempts tls handshake with given timeout -func (c *Client) tlsHandshakeWithTimeout(tlsConn *tls.Conn, ctx context.Context) error { +func (c *Client) tlsHandshakeWithTimeout(ctx context.Context, tlsConn *tls.Conn) error { errChan := make(chan error, 1) - defer close(errChan) + + go func() { + errChan <- tlsConn.Handshake() + }() select { case <-ctx.Done(): return errorutil.NewWithTag("ztls", "timeout while attempting handshake") //nolint - case errChan <- tlsConn.Handshake(): - } - - err := <-errChan - if err == tls.ErrCertsOnly { - err = nil + case err := <-errChan: + if err == tls.ErrCertsOnly { + err = nil + } + return err } - return err }