Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions pkg/output/file_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,21 @@ package output
import (
"bufio"
"os"
"sync/atomic"
)

const (
// flushThreshold is the number of writes after which we flush to disk.
// This ensures data is persisted periodically to minimize data loss
// if the process hangs or crashes.
flushThreshold = 100
)

// fileWriter is a concurrent file based output writer.
type fileWriter struct {
file *os.File
writer *bufio.Writer
file *os.File
writer *bufio.Writer
writeCount atomic.Int64
}

// NewFileOutputWriter creates a new buffered writer for a file
Expand All @@ -27,7 +36,18 @@ func (w *fileWriter) Write(data []byte) error {
return err
}
_, err = w.writer.WriteRune('\n')
return err
if err != nil {
return err
}

// Periodic flush to minimize data loss on hangs/crashes
count := w.writeCount.Add(1)
if count%flushThreshold == 0 {
if flushErr := w.writer.Flush(); flushErr != nil {
return flushErr
}
}
return nil
}

// Close closes the underlying writer flushing everything to disk
Expand Down
77 changes: 77 additions & 0 deletions pkg/output/file_writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package output

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestFileWriterPeriodicFlush(t *testing.T) {
// Create temp file
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test_output.jsonl")

// Create writer
writer, err := newFileOutputWriter(tmpFile)
require.NoError(t, err)
defer writer.Close()

// Write exactly flushThreshold entries
for i := 0; i < flushThreshold; i++ {
err := writer.Write([]byte(`{"test": "data"}`))
require.NoError(t, err)
}

// After flushThreshold writes, data should be flushed to disk
// Read file contents without closing writer
contents, err := os.ReadFile(tmpFile)
require.NoError(t, err)

// File should have content (periodic flush worked)
assert.NotEmpty(t, contents, "file should have content after %d writes due to periodic flush", flushThreshold)
}

func TestFileWriterWriteCount(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test_count.jsonl")

writer, err := newFileOutputWriter(tmpFile)
require.NoError(t, err)
defer writer.Close()

// Write some entries
for i := 0; i < 50; i++ {
err := writer.Write([]byte(`{"count": 1}`))
require.NoError(t, err)
}

// Verify write count is tracked
count := writer.writeCount.Load()
assert.Equal(t, int64(50), count, "write count should be 50")
}

func TestFileWriterCloseFlushes(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test_close.jsonl")

writer, err := newFileOutputWriter(tmpFile)
require.NoError(t, err)

// Write less than flushThreshold (so periodic flush won't trigger)
for i := 0; i < flushThreshold/2; i++ {
err := writer.Write([]byte(`{"close": "test"}`))
require.NoError(t, err)
}

// Close should flush remaining data
err = writer.Close()
require.NoError(t, err)

// Verify all data was written
contents, err := os.ReadFile(tmpFile)
require.NoError(t, err)
assert.NotEmpty(t, contents, "file should have content after close")
}
13 changes: 12 additions & 1 deletion pkg/tlsx/tls/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,21 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con

conn := tls.Client(baseConn, baseCfg)

if err := conn.Handshake(); err == nil {
// Create context with timeout for cipher enumeration handshake
ctx := context.Background()
var cancel context.CancelFunc
if c.options.Timeout != 0 {
ctx, cancel = context.WithTimeout(ctx, time.Duration(c.options.Timeout)*time.Second)
}

if err := conn.HandshakeContext(ctx); err == nil {
ciphersuite := conn.ConnectionState().CipherSuite
enumeratedCiphers = append(enumeratedCiphers, tls.CipherSuiteName(ciphersuite))
}
// Cancel context per-iteration to release timer resources immediately
if cancel != nil {
cancel()
}
_ = conn.Close() // close baseConn internally
}
return enumeratedCiphers, nil
Expand Down
91 changes: 91 additions & 0 deletions pkg/tlsx/ztls/timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package ztls

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

// TestHandshakeTimeoutCancellation verifies that the handshake timeout
// properly cancels when the context is cancelled, rather than blocking
// indefinitely on the handshake operation.
func TestHandshakeTimeoutCancellation(t *testing.T) {
// Create a very short timeout context
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()

// Verify context cancellation works as expected
select {
case <-ctx.Done():
// Expected - context should timeout
assert.Error(t, ctx.Err(), "context should have error after timeout")
case <-time.After(100 * time.Millisecond):
t.Fatal("context timeout did not trigger")
}
}

// TestContextSelectBehavior verifies that a goroutine-based approach
// allows the select statement to properly choose between completion
// and context cancellation.
func TestContextSelectBehavior(t *testing.T) {
// This test demonstrates the correct pattern for timeout-based
// handshakes: running the blocking operation in a goroutine
// so the select can properly evaluate both cases.

ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

resultChan := make(chan string, 1)

// Simulate a slow operation in a goroutine
go func() {
time.Sleep(200 * time.Millisecond) // Slower than timeout
resultChan <- "completed"
}()

select {
case <-ctx.Done():
// This is the expected path - timeout should win
assert.Equal(t, context.DeadlineExceeded, ctx.Err())
case result := <-resultChan:
t.Fatalf("should have timed out, but got result: %s", result)
}
}

// TestNoDeadlockOnTimeout ensures that the timeout mechanism doesn't
// cause goroutine leaks or deadlocks.
func TestNoDeadlockOnTimeout(t *testing.T) {
done := make(chan struct{})

go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()

errChan := make(chan error, 1)

// Run blocking operation in goroutine
go func() {
time.Sleep(100 * time.Millisecond) // Simulate slow handshake
errChan <- nil
}()

select {
case <-ctx.Done():
// Timeout occurred - this is expected
case <-errChan:
// Operation completed
}

close(done)
}()

// Test should complete without deadlock
select {
case <-done:
// Success - no deadlock
case <-time.After(500 * time.Millisecond):
t.Fatal("test timed out - possible deadlock")
}
}
38 changes: 28 additions & 10 deletions pkg/tlsx/ztls/ztls.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,21 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con
conn := tls.Client(baseConn, baseCfg)
baseCfg.CipherSuites = []uint16{ztlsCiphers[v]}

if err := c.tlsHandshakeWithTimeout(conn, context.TODO()); err == nil {
// Create context with timeout for cipher enumeration handshake
ctx := context.Background()
var cancel context.CancelFunc
if c.options.Timeout != 0 {
ctx, cancel = context.WithTimeout(ctx, time.Duration(c.options.Timeout)*time.Second)
}

if err := c.tlsHandshakeWithTimeout(conn, ctx); err == nil {
h1 := conn.GetHandshakeLog()
enumeratedCiphers = append(enumeratedCiphers, h1.ServerHello.CipherSuite.String())
}
// Cancel context per-iteration to release timer resources immediately
if cancel != nil {
cancel()
}
_ = conn.Close() // also closes baseConn internally
}
return enumeratedCiphers, nil
Expand Down Expand Up @@ -320,20 +331,27 @@ func (c *Client) getConfig(hostname, ip, port string, options clients.ConnectOpt
return config, nil
}

// tlsHandshakeWithCtx attempts tls handshake with given timeout
// tlsHandshakeWithTimeout attempts tls handshake with given timeout.
// The handshake is executed in a goroutine to ensure the context timeout
// is properly respected and the select can choose between completion and timeout.
func (c *Client) tlsHandshakeWithTimeout(tlsConn *tls.Conn, ctx context.Context) error {
errChan := make(chan error, 1)
defer close(errChan)

// Run handshake in goroutine so the select can properly choose
// between handshake completion and context timeout
go func() {
errChan <- tlsConn.Handshake()
}()

select {
case <-ctx.Done():
// Close the connection to unblock the handshake goroutine
_ = tlsConn.Close()
return errorutil.NewWithTag("ztls", "timeout while attempting handshake") //nolint
case errChan <- tlsConn.Handshake():
}

err := <-errChan
if err == tls.ErrCertsOnly {
err = nil
case err := <-errChan:
if err == tls.ErrCertsOnly {
return nil
}
return err
}
return err
}