diff --git a/frontend/src/libs/index.ts b/frontend/src/libs/index.ts index 523923ab4b..f8a70f146e 100644 --- a/frontend/src/libs/index.ts +++ b/frontend/src/libs/index.ts @@ -91,15 +91,6 @@ export const riseRouterException = (status = 404, json = 'Not Found'): never => throw new Response(json, { status }); }; -export const base64ToArrayBuffer = (base64: string) => { - const binaryString = atob(base64); - const bytes = new Uint8Array(binaryString.length); - for (let i = 0; i < binaryString.length; i++) { - bytes[i] = binaryString.charCodeAt(i); - } - return bytes; -}; - export const isValidUrl = (urlString: string) => { try { return Boolean(new URL(urlString)); diff --git a/frontend/src/pages/Runs/Details/Logs/index.tsx b/frontend/src/pages/Runs/Details/Logs/index.tsx index aaffbf41cc..6fc5501039 100644 --- a/frontend/src/pages/Runs/Details/Logs/index.tsx +++ b/frontend/src/pages/Runs/Details/Logs/index.tsx @@ -31,7 +31,7 @@ export const Logs: React.FC = ({ className, projectName, runName, jobSub const writeDataToTerminal = (logs: ILogItem[]) => { logs.forEach((logItem) => { - terminalInstance.current.write(logItem.message); + terminalInstance.current.write(logItem.message.replace(/(? { const logs = response.logs.map((logItem) => ({ ...logItem, - message: base64ToArrayBuffer(logItem.message as string), + message: logItem.message, })); return { diff --git a/frontend/src/types/log.d.ts b/frontend/src/types/log.d.ts index eec182c1ec..99e9532c8c 100644 --- a/frontend/src/types/log.d.ts +++ b/frontend/src/types/log.d.ts @@ -1,7 +1,7 @@ declare interface ILogItem { log_source: 'stdout' | 'stderr'; timestamp: string; - message: string | Uint8Array; + message: string; } declare type TRequestLogsParams = { diff --git a/runner/go.mod b/runner/go.mod index 22dad6466e..c619383bd0 100644 --- a/runner/go.mod +++ b/runner/go.mod @@ -1,6 +1,6 @@ module github.com/dstackai/dstack/runner -go 1.23 +go 1.23.8 require ( github.com/alexellis/go-execute/v2 v2.2.1 @@ -10,6 +10,7 @@ require ( github.com/docker/docker v26.0.0+incompatible github.com/docker/go-connections v0.5.0 github.com/docker/go-units v0.5.0 + github.com/dstackai/ansistrip v0.0.6 github.com/go-git/go-git/v5 v5.12.0 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f github.com/gorilla/websocket v1.5.1 @@ -62,6 +63,7 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.2.2 // indirect + github.com/tidwall/btree v1.7.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/ulikunitz/xz v0.5.12 // indirect diff --git a/runner/go.sum b/runner/go.sum index 41e133c465..801974184f 100644 --- a/runner/go.sum +++ b/runner/go.sum @@ -47,6 +47,8 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dstackai/ansistrip v0.0.6 h1:6qqeDNWt8NoqfkY1CxKUvdHpJzBl89LOE3wMwptVpaI= +github.com/dstackai/ansistrip v0.0.6/go.mod h1:w3ejXI0twxDv6bPXhkOaPeYdbwz2nwcrcvFoZGqi9F0= github.com/ebitengine/purego v0.8.1 h1:sdRKd6plj7KYW33EH5As6YKfe8m9zbN9JMrOjNVF/BE= github.com/ebitengine/purego v0.8.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU= @@ -171,6 +173,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= diff --git a/runner/internal/executor/base.go b/runner/internal/executor/base.go index 2163ca9204..554bd7646a 100644 --- a/runner/internal/executor/base.go +++ b/runner/internal/executor/base.go @@ -10,7 +10,7 @@ import ( type Executor interface { GetHistory(timestamp int64) *schemas.PullResponse - GetJobLogsHistory() []schemas.LogEvent + GetJobWsLogsHistory() []schemas.LogEvent GetRunnerState() string Run(ctx context.Context) error SetCodePath(codePath string) diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index c29913dcb1..d2ced5dc3a 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -19,6 +19,7 @@ import ( "time" "github.com/creack/pty" + "github.com/dstackai/ansistrip" "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/connections" "github.com/dstackai/dstack/runner/internal/gerrors" @@ -28,6 +29,18 @@ import ( "github.com/prometheus/procfs" ) +// TODO: Tune these parameters for optimal experience/performance +const ( + // Output is flushed when the cursor doesn't move for this duration + AnsiStripFlushInterval = 500 * time.Millisecond + + // Output is flushed regardless of cursor activity after this maximum delay + AnsiStripMaxDelay = 3 * time.Second + + // Maximum buffer size for ansistrip + MaxBufferSize = 32 * 1024 // 32KB +) + type ConnectionTracker interface { GetNoConnectionsSecs() int64 Track(ticker <-chan time.Time) @@ -54,6 +67,7 @@ type RunExecutor struct { state string jobStateHistory []schemas.JobStateEvent jobLogs *appendWriter + jobWsLogs *appendWriter runnerLogs *appendWriter timestamp *MonotonicTimestamp @@ -110,6 +124,7 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string, sshPort i state: WaitSubmit, jobStateHistory: make([]schemas.JobStateEvent, 0), jobLogs: newAppendWriter(mu, timestamp), + jobWsLogs: newAppendWriter(mu, timestamp), runnerLogs: newAppendWriter(mu, timestamp), timestamp: timestamp, @@ -153,7 +168,9 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { } }() - logger := io.MultiWriter(runnerLogFile, os.Stdout, ex.runnerLogs) + stripper := ansistrip.NewWriter(ex.runnerLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize) + defer stripper.Close() + logger := io.MultiWriter(runnerLogFile, os.Stdout, stripper) ctx = log.WithLogger(ctx, log.NewEntry(logger, int(log.DefaultEntry.Logger.Level))) // todo loglevel log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String()) @@ -455,7 +472,9 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error defer func() { _ = ptm.Close() }() defer func() { _ = cmd.Wait() }() // release resources if copy fails - logger := io.MultiWriter(jobLogFile, ex.jobLogs) + stripper := ansistrip.NewWriter(ex.jobLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize) + defer stripper.Close() + logger := io.MultiWriter(jobLogFile, ex.jobWsLogs, stripper) _, err = io.Copy(logger, ptm) if err != nil && !isPtyError(err) { return gerrors.Wrap(err) diff --git a/runner/internal/executor/executor_test.go b/runner/internal/executor/executor_test.go index 8d275b1375..e13184513f 100644 --- a/runner/internal/executor/executor_test.go +++ b/runner/internal/executor/executor_test.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "testing" "time" @@ -17,8 +18,6 @@ import ( "github.com/stretchr/testify/require" ) -// todo test get history - func TestExecutor_WorkingDir_Current(t *testing.T) { var b bytes.Buffer ex := makeTestExecutor(t) @@ -28,7 +27,8 @@ func TestExecutor_WorkingDir_Current(t *testing.T) { err := ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - assert.Equal(t, ex.workingDir+"\r\n", b.String()) + // Normalize line endings for cross-platform compatibility. + assert.Equal(t, ex.workingDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } func TestExecutor_WorkingDir_Nil(t *testing.T) { @@ -39,7 +39,7 @@ func TestExecutor_WorkingDir_Nil(t *testing.T) { err := ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - assert.Equal(t, ex.workingDir+"\r\n", b.String()) + assert.Equal(t, ex.workingDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } func TestExecutor_HomeDir(t *testing.T) { @@ -49,7 +49,7 @@ func TestExecutor_HomeDir(t *testing.T) { err := ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - assert.Equal(t, ex.homeDir+"\r\n", b.String()) + assert.Equal(t, ex.homeDir+"\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } func TestExecutor_NonZeroExit(t *testing.T) { @@ -61,7 +61,7 @@ func TestExecutor_NonZeroExit(t *testing.T) { assert.Error(t, err) assert.NotEmpty(t, ex.jobStateHistory) exitStatus := ex.jobStateHistory[len(ex.jobStateHistory)-1].ExitStatus - assert.NotNil(t, exitStatus, ex.jobStateHistory) + assert.NotNil(t, exitStatus) assert.Equal(t, 100, *exitStatus) } @@ -96,7 +96,7 @@ func TestExecutor_LocalRepo(t *testing.T) { err = ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - assert.Equal(t, "bar\r\n", b.String()) + assert.Equal(t, "bar\n", strings.ReplaceAll(b.String(), "\r\n", "\n")) } func TestExecutor_Recover(t *testing.T) { @@ -148,8 +148,8 @@ func TestExecutor_RemoteRepo(t *testing.T) { err = ex.execJob(context.TODO(), io.Writer(&b)) assert.NoError(t, err) - expected := fmt.Sprintf("%s\r\n%s\r\n%s\r\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail) - assert.Equal(t, expected, b.String()) + expected := fmt.Sprintf("%s\n%s\n%s\n", ex.getRepoData().RepoHash, ex.getRepoData().RepoConfigName, ex.getRepoData().RepoConfigEmail) + assert.Equal(t, expected, strings.ReplaceAll(b.String(), "\r\n", "\n")) } /* Helpers */ @@ -236,3 +236,98 @@ func TestWriteDstackProfile(t *testing.T) { assert.Equal(t, value, string(out)) } } + +func TestExecutor_Logs(t *testing.T) { + var b bytes.Buffer + ex := makeTestExecutor(t) + // Use printf to generate ANSI control codes. + // \033[31m = red text, \033[1;32m = bold green text, \033[0m = reset + ex.jobSpec.Commands = append(ex.jobSpec.Commands, "printf '\\033[31mRed Hello World\\033[0m\\n' && printf '\\033[1;32mBold Green Line 2\\033[0m\\n' && printf 'Line 3\\n'") + + err := ex.execJob(context.TODO(), io.Writer(&b)) + assert.NoError(t, err) + + logHistory := ex.GetHistory(0).JobLogs + assert.NotEmpty(t, logHistory) + + logString := combineLogMessages(logHistory) + normalizedLogString := strings.ReplaceAll(logString, "\r\n", "\n") + + expectedOutput := "Red Hello World\nBold Green Line 2\nLine 3\n" + assert.Equal(t, expectedOutput, normalizedLogString, "Should strip ANSI codes from regular logs") + + // Verify timestamps are in order + assert.Greater(t, len(logHistory), 0) + for i := 1; i < len(logHistory); i++ { + assert.GreaterOrEqual(t, logHistory[i].Timestamp, logHistory[i-1].Timestamp) + } +} + +func TestExecutor_LogsWithErrors(t *testing.T) { + var b bytes.Buffer + ex := makeTestExecutor(t) + ex.jobSpec.Commands = append(ex.jobSpec.Commands, "echo 'Success message' && echo 'Error message' >&2 && exit 1") + + err := ex.execJob(context.TODO(), io.Writer(&b)) + assert.Error(t, err) + + logHistory := ex.GetHistory(0).JobLogs + assert.NotEmpty(t, logHistory) + + logString := combineLogMessages(logHistory) + normalizedLogString := strings.ReplaceAll(logString, "\r\n", "\n") + + expectedOutput := "Success message\nError message\n" + assert.Equal(t, expectedOutput, normalizedLogString) +} + +func TestExecutor_LogsAnsiCodeHandling(t *testing.T) { + var b bytes.Buffer + ex := makeTestExecutor(t) + + // Test a variety of ANSI escape sequences on stdout and stderr. + cmd := "printf '\\033[31mRed\\033[0m \\033[32mGreen\\033[0m\\n' && " + + "printf '\\033[1mBold\\033[0m \\033[4mUnderline\\033[0m\\n' && " + + "printf '\\033[s\\033[uPlain text\\n' >&2" + + ex.jobSpec.Commands = append(ex.jobSpec.Commands, cmd) + + err := ex.execJob(context.TODO(), io.Writer(&b)) + assert.NoError(t, err) + + // 1. Check WebSocket logs, which should preserve ANSI codes. + wsLogHistory := ex.GetJobWsLogsHistory() + assert.NotEmpty(t, wsLogHistory) + wsLogString := combineLogMessages(wsLogHistory) + normalizedWsLogString := strings.ReplaceAll(wsLogString, "\r\n", "\n") + + expectedWsOutput := "\033[31mRed\033[0m \033[32mGreen\033[0m\n" + + "\033[1mBold\033[0m \033[4mUnderline\033[0m\n" + + "\033[s\033[uPlain text\n" + assert.Equal(t, expectedWsOutput, normalizedWsLogString, "Websocket logs should preserve ANSI codes") + + // 2. Check regular job logs, which should have ANSI codes stripped. + regularLogHistory := ex.GetHistory(0).JobLogs + assert.NotEmpty(t, regularLogHistory) + regularLogString := combineLogMessages(regularLogHistory) + normalizedRegularLogString := strings.ReplaceAll(regularLogString, "\r\n", "\n") + + expectedRegularOutput := "Red Green\n" + + "Bold Underline\n" + + "Plain text\n" + assert.Equal(t, expectedRegularOutput, normalizedRegularLogString, "Regular logs should have ANSI codes stripped") + + // Verify timestamps are ordered for both log types. + assert.Greater(t, len(wsLogHistory), 0) + for i := 1; i < len(wsLogHistory); i++ { + assert.GreaterOrEqual(t, wsLogHistory[i].Timestamp, wsLogHistory[i-1].Timestamp) + } +} + +func combineLogMessages(logHistory []schemas.LogEvent) string { + var logOutput bytes.Buffer + for _, logEvent := range logHistory { + logOutput.Write(logEvent.Message) + } + return logOutput.String() +} diff --git a/runner/internal/executor/query.go b/runner/internal/executor/query.go index 1dff4e330c..6678e5f8d7 100644 --- a/runner/internal/executor/query.go +++ b/runner/internal/executor/query.go @@ -4,8 +4,8 @@ import ( "github.com/dstackai/dstack/runner/internal/schemas" ) -func (ex *RunExecutor) GetJobLogsHistory() []schemas.LogEvent { - return ex.jobLogs.history +func (ex *RunExecutor) GetJobWsLogsHistory() []schemas.LogEvent { + return ex.jobWsLogs.history } func (ex *RunExecutor) GetHistory(timestamp int64) *schemas.PullResponse { diff --git a/runner/internal/runner/api/ws.go b/runner/internal/runner/api/ws.go index cade1170a2..ebb0caea22 100644 --- a/runner/internal/runner/api/ws.go +++ b/runner/internal/runner/api/ws.go @@ -34,23 +34,23 @@ func (s *Server) streamJobLogs(conn *websocket.Conn) { for { s.executor.RLock() - jobLogsHistory := s.executor.GetJobLogsHistory() + jobLogsWsHistory := s.executor.GetJobWsLogsHistory() select { case <-s.shutdownCh: - if currentPos >= len(jobLogsHistory) { + if currentPos >= len(jobLogsWsHistory) { s.executor.RUnlock() close(s.wsDoneCh) return } default: - if currentPos >= len(jobLogsHistory) { + if currentPos >= len(jobLogsWsHistory) { s.executor.RUnlock() time.Sleep(100 * time.Millisecond) continue } } - for currentPos < len(jobLogsHistory) { - if err := conn.WriteMessage(websocket.BinaryMessage, jobLogsHistory[currentPos].Message); err != nil { + for currentPos < len(jobLogsWsHistory) { + if err := conn.WriteMessage(websocket.BinaryMessage, jobLogsWsHistory[currentPos].Message); err != nil { s.executor.RUnlock() log.Error(context.TODO(), "Failed to write message", "err", err) return diff --git a/src/dstack/_internal/server/services/logs/aws.py b/src/dstack/_internal/server/services/logs/aws.py index 92155763c3..616db94dbb 100644 --- a/src/dstack/_internal/server/services/logs/aws.py +++ b/src/dstack/_internal/server/services/logs/aws.py @@ -17,7 +17,6 @@ from dstack._internal.server.services.logs.base import ( LogStorage, LogStorageError, - b64encode_raw_message, datetime_to_unix_time_ms, unix_time_ms_to_datetime, ) @@ -238,8 +237,7 @@ def _get_next_batch( skipped_future_events += 1 continue cw_event = self._runner_log_event_to_cloudwatch_event(event) - # as message is base64-encoded, length in bytes = length in code points. - message_size = len(cw_event["message"]) + self.MESSAGE_OVERHEAD_SIZE + message_size = len(event.message) + self.MESSAGE_OVERHEAD_SIZE if message_size > self.MESSAGE_MAX_SIZE: # we should never hit this limit, as we use `io.Copy` to copy from pty to logs, # which under the hood uses 32KiB buffer, see runner/internal/executor/executor.go, @@ -271,7 +269,7 @@ def _runner_log_event_to_cloudwatch_event( ) -> _CloudWatchLogEvent: return { "timestamp": runner_log_event.timestamp, - "message": b64encode_raw_message(runner_log_event.message), + "message": runner_log_event.message.decode(errors="replace"), } @contextmanager diff --git a/src/dstack/_internal/server/services/logs/filelog.py b/src/dstack/_internal/server/services/logs/filelog.py index b1c4ce2bc0..10cbe3d1a2 100644 --- a/src/dstack/_internal/server/services/logs/filelog.py +++ b/src/dstack/_internal/server/services/logs/filelog.py @@ -15,7 +15,6 @@ from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent from dstack._internal.server.services.logs.base import ( LogStorage, - b64encode_raw_message, unix_time_ms_to_datetime, ) @@ -56,32 +55,42 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi current_line = 0 try: - # FIXME: Do not read all the lines in memory with open(log_file_path) as f: - lines = f.readlines() - except FileNotFoundError: - pass - else: - for i, line in enumerate(lines): - if current_line < start_line: + # Skip to start_line if needed + for _ in range(start_line): + if f.readline() == "": + # File is shorter than start_line + return JobSubmissionLogs(logs=logs, next_token=next_token) + current_line += 1 + + # Read lines one by one + while True: + line = f.readline() + if line == "": # EOF + break + current_line += 1 - continue - log_event = LogEvent.__response__.parse_raw(line) - current_line += 1 + try: + log_event = LogEvent.__response__.parse_raw(line) + except Exception: + # Skip malformed lines + continue - if request.start_time and log_event.timestamp <= request.start_time: - continue - if request.end_time is not None and log_event.timestamp >= request.end_time: - break + if request.start_time and log_event.timestamp <= request.start_time: + continue + if request.end_time is not None and log_event.timestamp >= request.end_time: + break - logs.append(log_event) + logs.append(log_event) - if len(logs) >= request.limit: - # Only set next_token if there are more lines to read - if current_line < len(lines): - next_token = str(current_line) - break + if len(logs) >= request.limit: + # Check if there are more lines to read + if f.readline() != "": + next_token = str(current_line) + break + except FileNotFoundError: + pass return JobSubmissionLogs(logs=logs, next_token=next_token) @@ -137,5 +146,5 @@ def _runner_log_event_to_log_event(self, runner_log_event: RunnerLogEvent) -> Lo return LogEvent( timestamp=unix_time_ms_to_datetime(runner_log_event.timestamp), log_source=LogEventSource.STDOUT, - message=b64encode_raw_message(runner_log_event.message), + message=runner_log_event.message.decode(errors="replace"), ) diff --git a/src/dstack/_internal/server/services/logs/gcp.py b/src/dstack/_internal/server/services/logs/gcp.py index ac228e19e1..6e9314df26 100644 --- a/src/dstack/_internal/server/services/logs/gcp.py +++ b/src/dstack/_internal/server/services/logs/gcp.py @@ -14,7 +14,6 @@ from dstack._internal.server.services.logs.base import ( LogStorage, LogStorageError, - b64encode_raw_message, unix_time_ms_to_datetime, ) from dstack._internal.utils.common import batched @@ -137,15 +136,14 @@ def _write_logs_to_stream(self, stream_name: str, logs: List[RunnerLogEvent]): with self.logger.batch() as batcher: for batch in batched(logs, self.MAX_BATCH_SIZE): for log in batch: - message = b64encode_raw_message(log.message) + message = log.message.decode(errors="replace") timestamp = unix_time_ms_to_datetime(log.timestamp) - # as message is base64-encoded, length in bytes = length in code points - if len(message) > self.MAX_RUNNER_MESSAGE_SIZE: + if len(log.message) > self.MAX_RUNNER_MESSAGE_SIZE: logger.error( "Stream %s: skipping event at %s, message exceeds max size: %d > %d", stream_name, timestamp.isoformat(), - len(message), + len(log.message), self.MAX_RUNNER_MESSAGE_SIZE, ) continue diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index e1992068d0..6a19cb459d 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -1,4 +1,3 @@ -import base64 import queue import tempfile import threading @@ -229,7 +228,7 @@ def logs( ), ) for log in resp.logs: - yield base64.b64decode(log.message) + yield log.message.encode() next_token = resp.next_token if next_token is None: break diff --git a/src/tests/_internal/server/services/test_logs.py b/src/tests/_internal/server/services/test_logs.py index 3ece736a7c..0b94209175 100644 --- a/src/tests/_internal/server/services/test_logs.py +++ b/src/tests/_internal/server/services/test_logs.py @@ -1,4 +1,3 @@ -import base64 import logging from datetime import datetime, timedelta, timezone from pathlib import Path @@ -52,8 +51,8 @@ async def test_writes_logs(self, test_db, session: AsyncSession, tmp_path: Path) / "runner.log" ) assert runner_log_path.read_text() == ( - '{"timestamp":"2023-10-06T10:01:53.234000+00:00","log_source":"stdout","message":"SGVsbG8="}\n' - '{"timestamp":"2023-10-06T10:01:53.235000+00:00","log_source":"stdout","message":"V29ybGQ="}\n' + '{"timestamp":"2023-10-06T10:01:53.234000+00:00","log_source":"stdout","message":"Hello"}\n' + '{"timestamp":"2023-10-06T10:01:53.235000+00:00","log_source":"stdout","message":"World"}\n' ) @pytest.mark.asyncio @@ -120,12 +119,8 @@ async def test_poll_logs_with_next_token_pagination( job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 2 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log1".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[1].message == base64.b64encode( - "Log2".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log1" + assert job_submission_logs.logs[1].message == "Log2" assert job_submission_logs.next_token == "2" # Next line to read # Second page: use next_token @@ -133,12 +128,8 @@ async def test_poll_logs_with_next_token_pagination( job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 2 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log3".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[1].message == base64.b64encode( - "Log4".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log3" + assert job_submission_logs.logs[1].message == "Log4" assert job_submission_logs.next_token == "4" # Next line to read # Third page: get remaining log @@ -146,9 +137,7 @@ async def test_poll_logs_with_next_token_pagination( job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 1 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log5".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log5" assert job_submission_logs.next_token is None # No more logs @pytest.mark.asyncio @@ -183,12 +172,8 @@ async def test_poll_logs_with_start_from_specific_line( job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 2 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log2".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[1].message == base64.b64encode( - "Log3".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log2" + assert job_submission_logs.logs[1].message == "Log3" assert job_submission_logs.next_token is None @pytest.mark.asyncio @@ -279,9 +264,7 @@ async def test_poll_logs_with_time_filtering_and_pagination( # Should get Log3 first (timestamp > 235) assert len(job_submission_logs.logs) == 1 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log3".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log3" assert job_submission_logs.next_token == "3" # Get next page @@ -290,9 +273,7 @@ async def test_poll_logs_with_time_filtering_and_pagination( # Should get Log4 assert len(job_submission_logs.logs) == 1 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log4".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log4" # Should not have next_token since we reached end of file assert job_submission_logs.next_token is None @@ -359,9 +340,9 @@ async def test_next_token_pagination_complete_workflow( page1 = log_storage.poll_logs(project, poll_request) assert len(page1.logs) == 3 - assert page1.logs[0].message == base64.b64encode("Log1".encode()).decode() - assert page1.logs[1].message == base64.b64encode("Log2".encode()).decode() - assert page1.logs[2].message == base64.b64encode("Log3".encode()).decode() + assert page1.logs[0].message == "Log1" + assert page1.logs[1].message == "Log2" + assert page1.logs[2].message == "Log3" assert page1.next_token == "3" # Next line to read # Second page: use next_token @@ -369,9 +350,9 @@ async def test_next_token_pagination_complete_workflow( page2 = log_storage.poll_logs(project, poll_request) assert len(page2.logs) == 3 - assert page2.logs[0].message == base64.b64encode("Log4".encode()).decode() - assert page2.logs[1].message == base64.b64encode("Log5".encode()).decode() - assert page2.logs[2].message == base64.b64encode("Log6".encode()).decode() + assert page2.logs[0].message == "Log4" + assert page2.logs[1].message == "Log5" + assert page2.logs[2].message == "Log6" assert page2.next_token == "6" # Third page: get more logs @@ -379,9 +360,9 @@ async def test_next_token_pagination_complete_workflow( page3 = log_storage.poll_logs(project, poll_request) assert len(page3.logs) == 3 - assert page3.logs[0].message == base64.b64encode("Log7".encode()).decode() - assert page3.logs[1].message == base64.b64encode("Log8".encode()).decode() - assert page3.logs[2].message == base64.b64encode("Log9".encode()).decode() + assert page3.logs[0].message == "Log7" + assert page3.logs[1].message == "Log8" + assert page3.logs[2].message == "Log9" assert page3.next_token == "9" # Fourth page: get last log @@ -389,8 +370,8 @@ async def test_next_token_pagination_complete_workflow( page4 = log_storage.poll_logs(project, poll_request) assert len(page4.logs) == 1 - assert page4.logs[0].message == base64.b64encode("Log10".encode()).decode() - assert page4.next_token is None # No more logs + assert page4.logs[0].message == "Log10" + assert page4.next_token is None @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -428,15 +409,15 @@ async def test_next_token_with_time_filtering( page1 = log_storage.poll_logs(project, poll_request) assert len(page1.logs) == 2 - assert page1.logs[0].message == base64.b64encode("Log3".encode()).decode() - assert page1.logs[1].message == base64.b64encode("Log4".encode()).decode() + assert page1.logs[0].message == "Log3" + assert page1.logs[1].message == "Log4" assert page1.next_token == "4" # Get next page poll_request.next_token = page1.next_token page2 = log_storage.poll_logs(project, poll_request) assert len(page2.logs) == 1 - assert page2.logs[0].message == base64.b64encode("Log5".encode()).decode() + assert page2.logs[0].message == "Log5" assert page2.next_token is None @pytest.mark.asyncio @@ -467,16 +448,16 @@ async def test_next_token_edge_cases(self, test_db, session: AsyncSession, tmp_p result = log_storage.poll_logs(project, poll_request) assert len(result.logs) == 1 - assert result.logs[0].message == base64.b64encode("OnlyLog".encode()).decode() - assert result.next_token is None # No more logs available + assert result.logs[0].message == "OnlyLog" + assert result.next_token is None # Request with limit equal to available logs poll_request.limit = 1 result = log_storage.poll_logs(project, poll_request) assert len(result.logs) == 1 - assert result.logs[0].message == base64.b64encode("OnlyLog".encode()).decode() - assert result.next_token is None # No more logs available + assert result.logs[0].message == "OnlyLog" + assert result.next_token is None @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -577,15 +558,9 @@ async def test_poll_logs_with_limit(self, test_db, session: AsyncSession, tmp_pa # Should return only the first 3 logs and provide next_token assert len(job_submission_logs.logs) == 3 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log1".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[1].message == base64.b64encode( - "Log2".encode("utf-8") - ).decode("utf-8") - assert job_submission_logs.logs[2].message == base64.b64encode( - "Log3".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log1" + assert job_submission_logs.logs[1].message == "Log2" + assert job_submission_logs.logs[2].message == "Log3" # Should have next_token pointing to line 3 (fourth log) assert job_submission_logs.next_token == "3" @@ -594,9 +569,7 @@ async def test_poll_logs_with_limit(self, test_db, session: AsyncSession, tmp_pa poll_request.start_time = logs[3].timestamp job_submission_logs = log_storage.poll_logs(project, poll_request) assert len(job_submission_logs.logs) == 1 - assert job_submission_logs.logs[0].message == base64.b64encode( - "Log5".encode("utf-8") - ).decode("utf-8") + assert job_submission_logs.logs[0].message == "Log5" # Should not have next_token since we reached end of file assert job_submission_logs.next_token is None @@ -921,14 +894,14 @@ async def test_write_logs( logGroupName="test-group", logStreamName=expected_runner_stream, logEvents=[ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], ), call( logGroupName="test-group", logStreamName=expected_job_stream, logEvents=[ - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513235, "message": "World"}, ], ), ] @@ -1026,11 +999,11 @@ async def test_write_logs_not_in_chronological_order( logGroupName="test-group", logStreamName="test-proj/test-run/1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e/runner", logEvents=[ - {"timestamp": 1696586513235, "message": "MQ=="}, - {"timestamp": 1696586513236, "message": "Mg=="}, - {"timestamp": 1696586513237, "message": "Mw=="}, - {"timestamp": 1696586513237, "message": "NA=="}, - {"timestamp": 1696586513237, "message": "NQ=="}, + {"timestamp": 1696586513235, "message": "1"}, + {"timestamp": 1696586513236, "message": "2"}, + {"timestamp": 1696586513237, "message": "3"}, + {"timestamp": 1696586513237, "message": "4"}, + {"timestamp": 1696586513237, "message": "5"}, ], ) assert "events are not in chronological order" in caplog.text @@ -1068,7 +1041,7 @@ def _delta_ms(**kwargs: int) -> int: assert "skipping 1 past event(s)" in caplog.text assert "skipping 2 future event(s)" in caplog.text actual = [ - base64.b64decode(e["message"]).decode() + e["message"] for c in mock_client.put_log_events.call_args_list for e in c.kwargs["logEvents"] ] @@ -1113,8 +1086,8 @@ async def test_write_logs_batching_by_size( messages: List[str], expected: List[List[str]], ): - # maximum 6 bytes: 12 (in base64) + 26 (overhead) = 34 - monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34) + # maximum 6 bytes: 6 (raw bytes) + 26 (overhead) = 32 + monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 32) monkeypatch.setattr(CloudWatchLogStorage, "BATCH_MAX_SIZE", 60) log_storage.write_logs( project=project, @@ -1128,7 +1101,7 @@ async def test_write_logs_batching_by_size( ) assert mock_client.put_log_events.call_count == len(expected) actual = [ - [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]] + [e["message"] for e in c.kwargs["logEvents"]] for c in mock_client.put_log_events.call_args_list ] assert actual == expected @@ -1143,7 +1116,7 @@ async def test_write_logs_batching_by_size( [["111", "111", "111"], ["222"]], ], [ - ["111", "111", "111"] + ["222", "222", "toolong", "", "222222"], + ["111", "111", "111"] + ["222", "222", "toolongtoolong", "", "222222"], [["111", "111", "111"], ["222", "222", "222222"]], ], ], @@ -1160,8 +1133,8 @@ async def test_write_logs_batching_by_count( messages: List[str], expected: List[List[str]], ): - # maximum 6 bytes: 12 (in base64) + 26 (overhead) = 34 - monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 34) + # maximum 6 bytes: 6 (raw bytes) + 26 (overhead) = 32 + monkeypatch.setattr(CloudWatchLogStorage, "MESSAGE_MAX_SIZE", 32) monkeypatch.setattr(CloudWatchLogStorage, "EVENT_MAX_COUNT_IN_BATCH", 3) log_storage.write_logs( project=project, @@ -1175,7 +1148,7 @@ async def test_write_logs_batching_by_count( ) assert mock_client.put_log_events.call_count == len(expected) actual = [ - [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]] + [e["message"] for e in c.kwargs["logEvents"]] for c in mock_client.put_log_events.call_args_list ] assert actual == expected @@ -1218,7 +1191,7 @@ def _delta_ms(**kwargs: int) -> int: expected = [["1", "2", "3"], ["4", "5", "6"], ["7"]] assert mock_client.put_log_events.call_count == len(expected) actual = [ - [base64.b64decode(e["message"]).decode() for e in c.kwargs["logEvents"]] + [e["message"] for e in c.kwargs["logEvents"]] for c in mock_client.put_log_events.call_args_list ] assert actual == expected @@ -1232,8 +1205,8 @@ async def test_poll_logs_non_empty_response( poll_logs_request: PollLogsRequest, ): mock_client.get_log_events.return_value["events"] = [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ] poll_logs_request.limit = 2 job_submission_logs = log_storage.poll_logs(project, poll_logs_request) @@ -1242,12 +1215,12 @@ async def test_poll_logs_non_empty_response( LogEvent( timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc), log_source=LogEventSource.STDOUT, - message="SGVsbG8=", + message="Hello", ), LogEvent( timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc), log_source=LogEventSource.STDOUT, - message="V29ybGQ=", + message="World", ), ] @@ -1260,8 +1233,8 @@ async def test_poll_logs_descending_non_empty_response_on_first_call( poll_logs_request: PollLogsRequest, ): mock_client.get_log_events.return_value["events"] = [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ] poll_logs_request.descending = True poll_logs_request.limit = 2 @@ -1271,12 +1244,12 @@ async def test_poll_logs_descending_non_empty_response_on_first_call( LogEvent( timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc), log_source=LogEventSource.STDOUT, - message="V29ybGQ=", + message="World", ), LogEvent( timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc), log_source=LogEventSource.STDOUT, - message="SGVsbG8=", + message="Hello", ), ] @@ -1292,8 +1265,8 @@ async def test_next_token_ascending_pagination( # Setup response with nextForwardToken mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ], "nextBackwardToken": "bwd", "nextForwardToken": "fwd123", @@ -1327,8 +1300,8 @@ async def test_next_token_descending_pagination( # Setup response with nextBackwardToken mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ], "nextBackwardToken": "bwd456", "nextForwardToken": "fwd", @@ -1340,8 +1313,8 @@ async def test_next_token_descending_pagination( assert len(result.logs) == 2 # Events should be reversed for descending order - assert result.logs[0].message == "V29ybGQ=" - assert result.logs[1].message == "SGVsbG8=" + assert result.logs[0].message == "World" + assert result.logs[1].message == "Hello" assert result.next_token == "bwd456" # Should return nextBackwardToken # Verify API was called with correct parameters @@ -1363,7 +1336,7 @@ async def test_next_token_provided_in_request( """Test that provided next_token is passed to CloudWatch API""" mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], "nextBackwardToken": "bwd", "nextForwardToken": "new_fwd", @@ -1419,7 +1392,7 @@ async def test_next_token_with_time_filtering( """Test next_token behavior with time filtering""" mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], "nextBackwardToken": "bwd_with_time", "nextForwardToken": "fwd_with_time", @@ -1457,7 +1430,7 @@ async def test_next_token_missing_in_cloudwatch_response( """Test behavior when CloudWatch doesn't return next tokens""" mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], # No nextBackwardToken or nextForwardToken in response } @@ -1479,7 +1452,7 @@ async def test_next_token_empty_string_in_cloudwatch_response( """Test behavior when CloudWatch returns empty string tokens""" mock_client.get_log_events.return_value = { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513234, "message": "Hello"}, ], "nextBackwardToken": "", "nextForwardToken": "", @@ -1504,8 +1477,8 @@ async def test_next_token_pagination_workflow( mock_client.get_log_events.side_effect = [ { "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, + {"timestamp": 1696586513234, "message": "Hello"}, + {"timestamp": 1696586513235, "message": "World"}, ], "nextBackwardToken": "bwd", "nextForwardToken": "token_page2", @@ -1513,7 +1486,7 @@ async def test_next_token_pagination_workflow( # Second call - returns final logs without next_token { "events": [ - {"timestamp": 1696586513236, "message": "IQ=="}, + {"timestamp": 1696586513236, "message": "!"}, ], "nextBackwardToken": "final_bwd", "nextForwardToken": "final_fwd", @@ -1526,8 +1499,8 @@ async def test_next_token_pagination_workflow( page1 = log_storage.poll_logs(project, poll_logs_request) assert len(page1.logs) == 2 - assert page1.logs[0].message == "SGVsbG8=" - assert page1.logs[1].message == "V29ybGQ=" + assert page1.logs[0].message == "Hello" + assert page1.logs[1].message == "World" assert page1.next_token == "token_page2" # Second page using next_token @@ -1535,7 +1508,7 @@ async def test_next_token_pagination_workflow( page2 = log_storage.poll_logs(project, poll_logs_request) assert len(page2.logs) == 1 - assert page2.logs[0].message == "IQ==" + assert page2.logs[0].message == "!" assert page2.next_token == "final_fwd" # Verify both API calls