diff --git a/remote/log_file_security_test.go b/remote/log_file_security_test.go new file mode 100644 index 0000000..7292deb --- /dev/null +++ b/remote/log_file_security_test.go @@ -0,0 +1,262 @@ +package remote + +import ( + "bytes" + "errors" + "net" + "os" + "path/filepath" + "runtime" + "syscall" + "testing" + "time" +) + +type partialWriter struct { + maxChunk int + buf bytes.Buffer +} + +func (w *partialWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + n := w.maxChunk + if n > len(p) { + n = len(p) + } + return w.buf.Write(p[:n]) +} + +type errWriter struct{} + +func (w errWriter) Write(_ []byte) (int, error) { + return 0, errors.New("boom") +} + +func TestOpenOrCreateLogFileForWrite_CreatesNewFile(t *testing.T) { + path := filepath.Join(t.TempDir(), "wp-cli-guid.log") + + f, err := openOrCreateLogFileForWrite(path) + if err != nil { + t.Fatalf("openOrCreateLogFileForWrite() error = %v", err) + } + defer f.Close() + + if _, err = f.Write([]byte("hello")); err != nil { + t.Fatalf("write() error = %v", err) + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat() error = %v", err) + } + if !info.Mode().IsRegular() { + t.Fatalf("expected regular file, got mode=%v", info.Mode()) + } +} + +func TestOpenOrCreateLogFileForWrite_TruncatesExistingRegularFile(t *testing.T) { + path := filepath.Join(t.TempDir(), "wp-cli-guid.log") + + if err := os.WriteFile(path, []byte("existing-content"), 0600); err != nil { + t.Fatalf("writeFile() setup error = %v", err) + } + + f, err := openOrCreateLogFileForWrite(path) + if err != nil { + t.Fatalf("openOrCreateLogFileForWrite() error = %v", err) + } + defer f.Close() + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat() error = %v", err) + } + if info.Size() != 0 { + t.Fatalf("expected truncated file size 0, got %d", info.Size()) + } +} + +func TestOpenOrCreateLogFileForWrite_RejectsSymlink(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink behavior differs on windows") + } + + dir := t.TempDir() + target := filepath.Join(dir, "target.log") + link := filepath.Join(dir, "wp-cli-guid.log") + + if err := os.WriteFile(target, []byte("do-not-touch"), 0600); err != nil { + t.Fatalf("writeFile() setup error = %v", err) + } + if err := os.Symlink(target, link); err != nil { + t.Fatalf("symlink() setup error = %v", err) + } + + f, err := openOrCreateLogFileForWrite(link) + if err == nil { + _ = f.Close() + t.Fatal("expected error when log path is a symlink") + } + + data, readErr := os.ReadFile(target) + if readErr != nil { + t.Fatalf("readFile() error = %v", readErr) + } + if string(data) != "do-not-touch" { + t.Fatalf("symlink target was modified, got %q", string(data)) + } +} + +func TestOpenOrCreateLogFileForWrite_RejectsFifoWithoutBlocking(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("fifo behavior differs on windows") + } + + path := filepath.Join(t.TempDir(), "wp-cli-guid.fifo") + if err := syscall.Mkfifo(path, 0600); err != nil { + t.Fatalf("mkfifo() setup error = %v", err) + } + + errCh := make(chan error, 1) + go func() { + f, err := openOrCreateLogFileForWrite(path) + if err == nil { + _ = f.Close() + } + errCh <- err + }() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("expected error when log path is a fifo") + } + case <-time.After(1 * time.Second): + t.Fatal("openOrCreateLogFileForWrite() blocked on fifo") + } +} + +func TestStreamLogs_MissingFileMessage(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + streamLogs(serverConn, "guid-that-does-not-exist") + }() + + if err := clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("setReadDeadline() error = %v", err) + } + + buf := make([]byte, 512) + n, err := clientConn.Read(buf) + if err != nil { + t.Fatalf("read() error = %v", err) + } + + got := string(buf[:n]) + if got != "The WP CLI log file for GUID guid-that-does-not-exist does not exist\n" { + t.Fatalf("unexpected message: %q", got) + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("streamLogs did not return") + } +} + +func TestOpenLogFileForRead_RegularFile(t *testing.T) { + path := filepath.Join(t.TempDir(), "wp-cli-guid.log") + if err := os.WriteFile(path, []byte("hello"), 0600); err != nil { + t.Fatalf("writeFile() setup error = %v", err) + } + + f, err := openLogFileForRead(path) + if err != nil { + t.Fatalf("openLogFileForRead() error = %v", err) + } + defer f.Close() +} + +func TestOpenLogFileForRead_RejectsFifoWithoutBlocking(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("fifo behavior differs on windows") + } + + path := filepath.Join(t.TempDir(), "wp-cli-guid.fifo") + if err := syscall.Mkfifo(path, 0600); err != nil { + t.Fatalf("mkfifo() setup error = %v", err) + } + + errCh := make(chan error, 1) + go func() { + f, err := openLogFileForRead(path) + if err == nil { + _ = f.Close() + } + errCh <- err + }() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("expected error when log path is a fifo") + } + case <-time.After(1 * time.Second): + t.Fatal("openLogFileForRead() blocked on fifo") + } +} + +func TestOpenLogFileForRead_RejectsSymlink(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink behavior differs on windows") + } + + dir := t.TempDir() + target := filepath.Join(dir, "target.log") + link := filepath.Join(dir, "wp-cli-guid.log") + + if err := os.WriteFile(target, []byte("do-not-touch"), 0600); err != nil { + t.Fatalf("writeFile() setup error = %v", err) + } + if err := os.Symlink(target, link); err != nil { + t.Fatalf("symlink() setup error = %v", err) + } + + f, err := openLogFileForRead(link) + if err == nil { + _ = f.Close() + t.Fatal("expected error when log path is a symlink") + } + + data, readErr := os.ReadFile(target) + if readErr != nil { + t.Fatalf("readFile() error = %v", readErr) + } + if string(data) != "do-not-touch" { + t.Fatalf("symlink target was modified, got %q", string(data)) + } +} + +func TestWriteAll_PartialWrites(t *testing.T) { + w := &partialWriter{maxChunk: 2} + if err := writeAll(w, []byte("abcdef")); err != nil { + t.Fatalf("writeAll() error = %v", err) + } + if got := w.buf.String(); got != "abcdef" { + t.Fatalf("writeAll() wrote %q, want %q", got, "abcdef") + } +} + +func TestWriteAll_WriteError(t *testing.T) { + err := writeAll(errWriter{}, []byte("abc")) + if err == nil { + t.Fatal("expected writeAll() to return error") + } +} diff --git a/remote/remote.go b/remote/remote.go index 95d5406..2821452 100644 --- a/remote/remote.go +++ b/remote/remote.go @@ -712,23 +712,17 @@ func runWpCliCmdRemote(conn net.Conn, GUID string, rows uint16, cols uint16, wpC log.Printf("launching %s - rows: %d, cols: %d, args: %s\n", GUID, rows, cols, strings.Join(cmdArgs, " ")) logFileName := fmt.Sprintf("/tmp/wp-cli-%s", GUID) - - if _, err := os.Stat(logFileName); nil == err { - log.Printf("runWpCliCmdRemote: Removing existing GUID logfile %s", logFileName) - os.Remove(logFileName) - } - log.Printf("Creating the logfile %s", logFileName) - logFile, err := os.OpenFile(logFileName, os.O_APPEND|os.O_WRONLY|os.O_CREATE|os.O_SYNC, 0666) + logFile, err := openOrCreateLogFileForWrite(logFileName) if nil != err { - conn.Write([]byte("unable to launch the remote WP CLI process: " + err.Error())) + conn.Write([]byte("unable to launch the remote WP CLI process")) conn.Close() return fmt.Errorf("runWpCliCmdRemote: error creating the WP CLI log file: %s", err.Error()) } watcher, err := fsnotify.NewWatcher() if err != nil { - conn.Write([]byte("unable to launch the remote WP CLI process: " + err.Error())) + conn.Write([]byte("unable to launch the remote WP CLI process")) logFile.Close() conn.Close() return fmt.Errorf("runWpCliCmdRemote: error launching the WP CLI log file watcher: %s", err.Error()) @@ -772,7 +766,7 @@ func runWpCliCmdRemote(conn net.Conn, GUID string, rows uint16, cols uint16, wpC return fmt.Errorf("runWpCliCmdRemote: error setting the WP CLI TTY to ignore CR: %s", e.Error()) } - readFile, err := os.OpenFile(logFileName, os.O_RDONLY, os.ModeCharDevice) + readFile, err := openLogFileForRead(logFileName) if nil != err { conn.Close() logFile.Close() @@ -1023,16 +1017,14 @@ func streamLogs(conn net.Conn, GUID string) { log.Printf("preparing to send the log file for GUID %s\n", GUID) logFileName = fmt.Sprintf("/tmp/wp-cli-%s", GUID) - - if _, err := os.Stat(logFileName); nil != err { - conn.Write([]byte(fmt.Sprintf("The WP CLI log file for GUID %s does not exist\n", GUID))) - log.Printf("The logfile %s does not exist\n", logFileName) - conn.Close() - return - } - - logFile, err := os.OpenFile(logFileName, os.O_RDONLY|os.O_SYNC, 0666) + logFile, err := openLogFileForRead(logFileName) if nil != err { + if os.IsNotExist(err) { + conn.Write([]byte(fmt.Sprintf("The WP CLI log file for GUID %s does not exist\n", GUID))) + log.Printf("The logfile %s does not exist\n", logFileName) + conn.Close() + return + } conn.Write([]byte("error reading the WP CLI log file\n")) log.Printf("error reading the WP CLI log file: %s\n", err.Error()) conn.Close() @@ -1046,13 +1038,113 @@ func streamLogs(conn net.Conn, GUID string) { if io.EOF == err { break } - conn.Write(buf[:read]) + if err != nil { + log.Printf("error reading the WP CLI log file: %s\n", err.Error()) + break + } + if err = writeAll(conn, buf[:read]); err != nil { + log.Printf("error writing the WP CLI log stream: %s\n", err.Error()) + break + } } conn.Close() logFile.Close() log.Printf("log file for GUID %s sent\n", GUID) } +func writeAll(w io.Writer, data []byte) error { + for len(data) > 0 { + n, err := w.Write(data) + if err != nil { + return err + } + if n <= 0 { + return io.ErrShortWrite + } + data = data[n:] + } + + return nil +} + +func openLogFileForRead(logFileName string) (*os.File, error) { + // Validate with a non-blocking open first to avoid hanging on FIFOs/special files. + logFile, err := os.OpenFile(logFileName, os.O_RDONLY|os.O_SYNC|syscall.O_NOFOLLOW|syscall.O_NONBLOCK, 0) + if err != nil { + return nil, err + } + + fileInfo, err := logFile.Stat() + if err != nil { + _ = logFile.Close() + return nil, err + } + + if !fileInfo.Mode().IsRegular() { + _ = logFile.Close() + return nil, fmt.Errorf("log file path %q is not a regular file", logFileName) + } + + if err := syscall.SetNonblock(int(logFile.Fd()), false); err != nil { + _ = logFile.Close() + return nil, err + } + + return logFile, nil +} + +func openOrCreateLogFileForWrite(logFileName string) (*os.File, error) { + // Use atomic create without following symlinks to avoid TOCTOU attacks on /tmp paths. + logFile, err := os.OpenFile(logFileName, os.O_APPEND|os.O_WRONLY|os.O_CREATE|os.O_EXCL|os.O_SYNC|syscall.O_NOFOLLOW, 0600) + if err == nil { + return logFile, nil + } + + if !os.IsExist(err) { + return nil, err + } + + // Compatibility: reuse existing GUID files, but validate safely before truncating. + logFile, err = os.OpenFile(logFileName, os.O_APPEND|os.O_WRONLY|os.O_SYNC|syscall.O_NOFOLLOW|syscall.O_NONBLOCK, 0) + if err != nil { + return nil, err + } + + fileInfo, err := logFile.Stat() + if err != nil { + _ = logFile.Close() + return nil, err + } + + if !fileInfo.Mode().IsRegular() { + _ = logFile.Close() + return nil, fmt.Errorf("log file path %q is not a regular file", logFileName) + } + + stat, ok := fileInfo.Sys().(*syscall.Stat_t) + if !ok { + _ = logFile.Close() + return nil, fmt.Errorf("could not determine owner for log file path %q", logFileName) + } + + if int(stat.Uid) != os.Geteuid() { + _ = logFile.Close() + return nil, fmt.Errorf("log file path %q is not owned by current user", logFileName) + } + + if err := syscall.SetNonblock(int(logFile.Fd()), false); err != nil { + _ = logFile.Close() + return nil, err + } + + if err := logFile.Truncate(0); err != nil { + _ = logFile.Close() + return nil, err + } + + return logFile, nil +} + /* Splits a string into an array based on whitespace except when that whitepace is inside double qoutes or escaped quotes */