Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion internal/types/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ type SidecarMount struct {

// TaskAssignmentMessage is sent from server to worker when a task is available
type TaskAssignmentMessage struct {
TaskID string `json:"task_id"`
TaskID string `json:"task_id"`
// ExecutionID identifies the concrete run execution being launched. It is
// distinct from TaskID/Task.ID, which identify the logical run and can be
// reused by follow-up or handoff executions.
ExecutionID string `json:"execution_id,omitempty"`
Task *Task `json:"task"`
DockerImage string `json:"docker_image,omitempty"`
// The "sidecar image" contains the warp agent binary and a couple other dependencies.
Expand Down
5 changes: 3 additions & 2 deletions internal/worker/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ import (
// and the backend interface, so backends don't need to handle common concerns like
// resolving environment variables, choosing default images, or building base CLI args.
type TaskParams struct {
TaskID string
Task *types.Task
TaskID string
ExecutionID string
Task *types.Task

// EnvVars contains pre-resolved common environment variables (TASK_ID, Git config,
// assignment env vars). Backends append their own config-specific env vars.
Expand Down
52 changes: 35 additions & 17 deletions internal/worker/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ const (
kubernetesWorkerHashLabel = "oz-worker-hash"
kubernetesTaskIDLabel = "oz-task-id"
kubernetesTaskHashLabel = "oz-task-hash"
kubernetesExecutionIDLabel = "oz-execution-id"
kubernetesExecutionHashLabel = "oz-execution-hash"

// maxLogBytes caps the amount of container log data read into memory per
// container to avoid OOM when a task produces excessive output.
Expand Down Expand Up @@ -131,8 +133,9 @@ func (b *KubernetesBackend) ExecuteTask(ctx context.Context, params *TaskParams)
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonSidecarPrep, err)
}

jobName := sanitizeKubernetesJobName(params.TaskID)
jobLabels := b.baseLabels(params.TaskID)
executionID := taskExecutionID(params)
jobName := sanitizeKubernetesJobName(executionID)
jobLabels := b.baseLabels(params.TaskID, executionID)
jobAnnotations := copyStringMap(b.config.ExtraAnnotations)
pullPolicy := normalizePullPolicy(b.config.ImagePullPolicy)

Expand Down Expand Up @@ -305,7 +308,7 @@ func (b *KubernetesBackend) ExecuteTask(ctx context.Context, params *TaskParams)
}
defer jobWatcher.Stop()

podWatcher, err := b.watchTaskPods(ctx, params.TaskID)
podWatcher, err := b.watchTaskPods(ctx, executionID)
if err != nil {
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonPodWatch, fmt.Errorf("failed to watch Pods for Job %s: %w", jobName, err))
}
Expand Down Expand Up @@ -343,15 +346,15 @@ func (b *KubernetesBackend) ExecuteTask(ctx context.Context, params *TaskParams)
if !ok {
continue
}
if result := b.handleJobState(ctx, jobState, params.TaskID); result != nil {
if result := b.handleJobState(ctx, jobState, params.TaskID, executionID); result != nil {
return result.err
}

case event, ok := <-podWatcher.ResultChan():
if !ok {
// Watch closed; reopen.
podWatcher.Stop()
podWatcher, err = b.watchTaskPods(ctx, params.TaskID)
podWatcher, err = b.watchTaskPods(ctx, executionID)
if err != nil {
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonPodWatch, fmt.Errorf("failed to re-watch Pods for Job %s: %w", jobName, err))
}
Expand All @@ -360,7 +363,7 @@ func (b *KubernetesBackend) ExecuteTask(ctx context.Context, params *TaskParams)
if event.Type == watch.Error {
log.Warnf(ctx, "Pod watch error for Job %s, reopening", jobName)
podWatcher.Stop()
podWatcher, err = b.watchTaskPods(ctx, params.TaskID)
podWatcher, err = b.watchTaskPods(ctx, executionID)
if err != nil {
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonPodWatch, fmt.Errorf("failed to re-watch Pods for Job %s: %w", jobName, err))
}
Expand All @@ -387,11 +390,11 @@ func (b *KubernetesBackend) ExecuteTask(ctx context.Context, params *TaskParams)
}
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonJobWatch, fmt.Errorf("failed to get Job %s: %w", jobName, err))
}
if result := b.handleJobState(ctx, jobState, params.TaskID); result != nil {
if result := b.handleJobState(ctx, jobState, params.TaskID, executionID); result != nil {
return result.err
}

pods, err := b.listTaskPods(ctx, params.TaskID)
pods, err := b.listTaskPods(ctx, executionID)
if err != nil {
return newBackendFailure(metrics.TaskFailurePhaseBackend, metrics.TaskFailureReasonPodWatch, fmt.Errorf("failed to list task pods for Job %s: %w", jobName, err))
}
Expand Down Expand Up @@ -430,11 +433,11 @@ type jobResult struct {

// handleJobState checks whether a Job has reached a terminal state and, if so,
// returns a *jobResult. A nil return means the Job is still in progress.
func (b *KubernetesBackend) handleJobState(ctx context.Context, jobState *batchv1.Job, taskID string) *jobResult {
func (b *KubernetesBackend) handleJobState(ctx context.Context, jobState *batchv1.Job, taskID, executionID string) *jobResult {
jobName := jobState.Name
if jobComplete(jobState) {
if zerolog.GlobalLevel() <= zerolog.DebugLevel {
pods, _ := b.listTaskPods(ctx, taskID)
pods, _ := b.listTaskPods(ctx, executionID)
logs := b.collectPodLogs(ctx, pods)
if logs != "" {
log.Debugf(ctx, "Job %s output:\n%s", jobName, logs)
Expand All @@ -444,7 +447,7 @@ func (b *KubernetesBackend) handleJobState(ctx context.Context, jobState *batchv
return &jobResult{err: nil}
}
if jobFailed(jobState) {
pods, _ := b.listTaskPods(ctx, taskID)
pods, _ := b.listTaskPods(ctx, executionID)
logs := b.collectPodLogs(ctx, pods)
if logs != "" {
log.Infof(ctx, "Job %s output:\n%s", jobName, logs)
Expand All @@ -464,9 +467,9 @@ func (b *KubernetesBackend) watchJob(ctx context.Context, jobName string) (watch
})
}

func (b *KubernetesBackend) watchTaskPods(ctx context.Context, taskID string) (watch.Interface, error) {
func (b *KubernetesBackend) watchTaskPods(ctx context.Context, executionID string) (watch.Interface, error) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In here and in listTaskPods below, do we need to be concerned about the handling existing tasks if the upgrade happens in flight?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now, if the worker's restarted it'll detach from existing tasks and won't try to re-watch them, so we should be fine. Ideally it would rediscover existing tasks though

return b.clientset.CoreV1().Pods(b.config.Namespace).Watch(ctx, metav1.ListOptions{
LabelSelector: fmt.Sprintf("%s=%s", kubernetesTaskHashLabel, kubernetesLabelHash(taskID)),
LabelSelector: fmt.Sprintf("%s=%s", kubernetesExecutionHashLabel, kubernetesLabelHash(executionID)),
})
}

Expand Down Expand Up @@ -795,21 +798,26 @@ func (b *KubernetesBackend) startupPreflightError(err error) error {
return fmt.Errorf("kubernetes startup preflight failed: the kubernetes backend requires creating task Jobs with a root init container for sidecar materialization; verify service account/RBAC and Pod Security or admission policy for namespace %q: %w", b.config.Namespace, err)
}

func (b *KubernetesBackend) baseLabels(taskID string) map[string]string {
func (b *KubernetesBackend) baseLabels(taskID, executionID string) map[string]string {
if strings.TrimSpace(executionID) == "" {
executionID = taskID
}
labels := copyStringMap(b.config.ExtraLabels)
if labels == nil {
labels = make(map[string]string, 4)
labels = make(map[string]string, 6)
}
labels[kubernetesWorkerIDLabel] = sanitizeKubernetesLabelValue(b.config.WorkerID)
labels[kubernetesWorkerHashLabel] = kubernetesLabelHash(b.config.WorkerID)
labels[kubernetesTaskIDLabel] = sanitizeKubernetesLabelValue(taskID)
labels[kubernetesTaskHashLabel] = kubernetesLabelHash(taskID)
labels[kubernetesExecutionIDLabel] = sanitizeKubernetesLabelValue(executionID)
labels[kubernetesExecutionHashLabel] = kubernetesLabelHash(executionID)
return labels
}

func (b *KubernetesBackend) listTaskPods(ctx context.Context, taskID string) ([]corev1.Pod, error) {
func (b *KubernetesBackend) listTaskPods(ctx context.Context, executionID string) ([]corev1.Pod, error) {
podList, err := b.clientset.CoreV1().Pods(b.config.Namespace).List(ctx, metav1.ListOptions{
LabelSelector: fmt.Sprintf("%s=%s", kubernetesTaskHashLabel, kubernetesLabelHash(taskID)),
LabelSelector: fmt.Sprintf("%s=%s", kubernetesExecutionHashLabel, kubernetesLabelHash(executionID)),
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -1050,6 +1058,16 @@ func kubernetesLabelHash(value string) string {
return fmt.Sprintf("%x", sum[:8])
}

func taskExecutionID(params *TaskParams) string {
if params == nil {
return ""
}
if executionID := strings.TrimSpace(params.ExecutionID); executionID != "" {
return executionID
}
return params.TaskID
}

func kubernetesTaskWrapperScript() string {
return strings.Join([]string{
"if [ -f \"$OZ_ENVIRONMENT_FILE\" ]; then",
Expand Down
76 changes: 68 additions & 8 deletions internal/worker/kubernetes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,34 @@ func TestKubernetesBackendBaseLabelsIncludeStableHashes(t *testing.T) {
},
}

labels := backend.baseLabels("Task A")
labels := backend.baseLabels("Task A", "Execution A")
if labels[kubernetesWorkerIDLabel] != "worker-a" {
t.Fatalf("worker label = %q, want %q", labels[kubernetesWorkerIDLabel], "worker-a")
}
if labels[kubernetesTaskIDLabel] != "task-a" {
t.Fatalf("task label = %q, want %q", labels[kubernetesTaskIDLabel], "task-a")
}
if labels[kubernetesExecutionIDLabel] != "execution-a" {
t.Fatalf("execution label = %q, want %q", labels[kubernetesExecutionIDLabel], "execution-a")
}
if labels[kubernetesWorkerHashLabel] != kubernetesLabelHash("Worker A") {
t.Fatalf("worker hash = %q, want %q", labels[kubernetesWorkerHashLabel], kubernetesLabelHash("Worker A"))
}
if labels[kubernetesTaskHashLabel] != kubernetesLabelHash("Task A") {
t.Fatalf("task hash = %q, want %q", labels[kubernetesTaskHashLabel], kubernetesLabelHash("Task A"))
}
if labels[kubernetesExecutionHashLabel] != kubernetesLabelHash("Execution A") {
t.Fatalf("execution hash = %q, want %q", labels[kubernetesExecutionHashLabel], kubernetesLabelHash("Execution A"))
}
}

func TestTaskExecutionIDFallsBackToTaskID(t *testing.T) {
if got := taskExecutionID(&TaskParams{TaskID: "task-1", ExecutionID: " execution-1 "}); got != "execution-1" {
t.Fatalf("execution ID = %q, want %q", got, "execution-1")
}
if got := taskExecutionID(&TaskParams{TaskID: "task-1"}); got != "task-1" {
t.Fatalf("fallback execution ID = %q, want %q", got, "task-1")
}
}

func TestKubernetesSidecarMaterializationScriptMatchesExpectedShell(t *testing.T) {
Expand Down Expand Up @@ -243,7 +258,7 @@ func TestHandleJobStateDetectsCompletion(t *testing.T) {
ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: "agents"},
Status: batchv1.JobStatus{},
}
if result := backend.handleJobState(ctx, job, "task-1"); result != nil {
if result := backend.handleJobState(ctx, job, "task-1", "execution-1"); result != nil {
t.Fatalf("expected nil for in-progress job, got %v", result.err)
}
})
Expand All @@ -257,7 +272,7 @@ func TestHandleJobStateDetectsCompletion(t *testing.T) {
},
},
}
result := backend.handleJobState(ctx, job, "task-1")
result := backend.handleJobState(ctx, job, "task-1", "execution-1")
if result == nil {
t.Fatal("expected non-nil result for completed job")
}
Expand All @@ -275,7 +290,7 @@ func TestHandleJobStateDetectsCompletion(t *testing.T) {
},
},
}
result := backend.handleJobState(ctx, job, "task-1")
result := backend.handleJobState(ctx, job, "task-1", "execution-1")
if result == nil {
t.Fatal("expected non-nil result for failed job")
}
Expand Down Expand Up @@ -350,8 +365,8 @@ func TestWatchTaskPodsReceivesEvents(t *testing.T) {
clientset: fakeClient,
}

taskID := "task-abc"
watcher, err := backend.watchTaskPods(context.Background(), taskID)
executionID := "execution-abc"
watcher, err := backend.watchTaskPods(context.Background(), executionID)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
Expand All @@ -363,7 +378,7 @@ func TestWatchTaskPodsReceivesEvents(t *testing.T) {
Name: "task-pod",
Namespace: "agents",
Labels: map[string]string{
kubernetesTaskHashLabel: kubernetesLabelHash(taskID),
kubernetesExecutionHashLabel: kubernetesLabelHash(executionID),
},
},
Spec: corev1.PodSpec{
Expand Down Expand Up @@ -622,6 +637,7 @@ func TestExecuteTaskUsesImageVolumesForSidecars(t *testing.T) {

err := backend.ExecuteTask(context.Background(), &TaskParams{
TaskID: "task-1",
ExecutionID: "execution-1",
DockerImage: "ubuntu:22.04",
BaseArgs: []string{"run"},
Sidecars: []types.SidecarMount{
Expand All @@ -637,6 +653,24 @@ func TestExecuteTaskUsesImageVolumesForSidecars(t *testing.T) {
if createdJob == nil {
t.Fatal("expected task job to be created")
}
if createdJob.Name != sanitizeKubernetesJobName("execution-1") {
t.Fatalf("job name = %q, want %q", createdJob.Name, sanitizeKubernetesJobName("execution-1"))
}
if createdJob.Name == sanitizeKubernetesJobName("task-1") {
t.Fatalf("job name should be execution-scoped, got task-scoped name %q", createdJob.Name)
}
if createdJob.Labels[kubernetesTaskIDLabel] != "task-1" {
t.Fatalf("task label = %q, want %q", createdJob.Labels[kubernetesTaskIDLabel], "task-1")
}
if createdJob.Labels[kubernetesTaskHashLabel] != kubernetesLabelHash("task-1") {
t.Fatalf("task hash label = %q, want %q", createdJob.Labels[kubernetesTaskHashLabel], kubernetesLabelHash("task-1"))
}
if createdJob.Labels[kubernetesExecutionIDLabel] != "execution-1" {
t.Fatalf("execution label = %q, want %q", createdJob.Labels[kubernetesExecutionIDLabel], "execution-1")
}
if createdJob.Labels[kubernetesExecutionHashLabel] != kubernetesLabelHash("execution-1") {
t.Fatalf("execution hash label = %q, want %q", createdJob.Labels[kubernetesExecutionHashLabel], kubernetesLabelHash("execution-1"))
}

if len(createdJob.Spec.Template.Spec.InitContainers) != 1 {
t.Fatalf("expected only setup init container, got %d", len(createdJob.Spec.Template.Spec.InitContainers))
Expand Down Expand Up @@ -669,6 +703,16 @@ func TestExecuteTaskUsesImageVolumesForSidecars(t *testing.T) {
}

taskContainer := createdJob.Spec.Template.Spec.Containers[0]
envMap := make(map[string]string, len(taskContainer.Env))
for _, env := range taskContainer.Env {
envMap[env.Name] = env.Value
}
if envMap["OZ_RUN_ID"] != "task-1" {
t.Fatalf("OZ_RUN_ID = %q, want %q", envMap["OZ_RUN_ID"], "task-1")
}
if _, ok := envMap["OZ_EXECUTION_ID"]; ok {
t.Fatal("expected OZ_EXECUTION_ID to be omitted from task container env")
}
var taskSidecarMount *corev1.VolumeMount
for i := range taskContainer.VolumeMounts {
if taskContainer.VolumeMounts[i].Name == "sidecar-0-image" {
Expand Down Expand Up @@ -771,6 +815,15 @@ func TestExecuteTaskUsesCopyInitContainersByDefault(t *testing.T) {
if createdJob == nil {
t.Fatal("expected task job to be created")
}
if createdJob.Name != sanitizeKubernetesJobName("task-1") {
t.Fatalf("job name = %q, want %q", createdJob.Name, sanitizeKubernetesJobName("task-1"))
}
if createdJob.Labels[kubernetesExecutionIDLabel] != "task-1" {
t.Fatalf("fallback execution label = %q, want %q", createdJob.Labels[kubernetesExecutionIDLabel], "task-1")
}
if createdJob.Labels[kubernetesExecutionHashLabel] != kubernetesLabelHash("task-1") {
t.Fatalf("fallback execution hash label = %q, want %q", createdJob.Labels[kubernetesExecutionHashLabel], kubernetesLabelHash("task-1"))
}
if createdJob.Spec.TTLSecondsAfterFinished == nil || *createdJob.Spec.TTLSecondsAfterFinished != defaultJobTTLSecondsAfterFinish {
t.Fatalf("expected default ttlSecondsAfterFinished %d, got %v", defaultJobTTLSecondsAfterFinish, createdJob.Spec.TTLSecondsAfterFinished)
}
Expand Down Expand Up @@ -813,6 +866,13 @@ func TestExecuteTaskUsesCopyInitContainersByDefault(t *testing.T) {
}

taskContainer := createdJob.Spec.Template.Spec.Containers[0]
envMap := make(map[string]string, len(taskContainer.Env))
for _, env := range taskContainer.Env {
envMap[env.Name] = env.Value
}
if _, ok := envMap["OZ_EXECUTION_ID"]; ok {
t.Fatal("expected fallback OZ_EXECUTION_ID to be omitted from task container env")
}
var taskSidecarMount *corev1.VolumeMount
for i := range taskContainer.VolumeMounts {
if taskContainer.VolumeMounts[i].Name == "sidecar-0-data" {
Expand Down Expand Up @@ -910,7 +970,7 @@ func TestKubernetesBackendShutdownPreservesWorkerJobs(t *testing.T) {
ObjectMeta: metav1.ObjectMeta{
Name: "task-job",
Namespace: "agents",
Labels: backend.baseLabels("task-1"),
Labels: backend.baseLabels("task-1", "execution-1"),
},
}
fakeClient := fake.NewSimpleClientset(job)
Expand Down
1 change: 1 addition & 0 deletions internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ func (w *Worker) prepareTaskParams(assignment *types.TaskAssignmentMessage) *Tas

return &TaskParams{
TaskID: assignment.TaskID,
ExecutionID: assignment.ExecutionID,
Task: task,
EnvVars: envVars,
BaseArgs: baseArgs,
Expand Down
Loading