From 29c298d2eda60661067b1a25d6271f0af7f76b23 Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Thu, 30 Apr 2026 11:09:48 +0000 Subject: [PATCH 1/2] feat(workloads): add agent state --- cmd/runners/main.go | 6 +- internal/config/config.go | 50 ++++- internal/server/workload_activity.go | 58 ++++++ internal/server/workload_logs_test.go | 6 +- internal/server/workloads.go | 118 ++++++++--- internal/server/workloads_test.go | 202 ++++++++++++++++--- migrations/0009_add_workload_agent_state.sql | 3 + 7 files changed, 376 insertions(+), 67 deletions(-) create mode 100644 internal/server/workload_activity.go create mode 100644 migrations/0009_add_workload_agent_state.sql diff --git a/cmd/runners/main.go b/cmd/runners/main.go index b4573e8..790b580 100644 --- a/cmd/runners/main.go +++ b/cmd/runners/main.go @@ -93,7 +93,7 @@ func run() error { defer notificationsConn.Close() grpcServer := grpc.NewServer() - runnersv1.RegisterRunnersServiceServer(grpcServer, server.New(server.Options{ + srv := server.New(server.Options{ Pool: pool, IdentityClient: identityv1.NewIdentityServiceClient(identityConn), AuthorizationClient: authorizationv1.NewAuthorizationServiceClient(authorizationConn), @@ -101,7 +101,9 @@ func run() error { ZitiManagementClient: zitiMgmtClient, NotificationsClient: notificationsv1.NewNotificationsServiceClient(notificationsConn), ZitiDialer: zitiManager, - })) + }) + runnersv1.RegisterRunnersServiceServer(grpcServer, srv) + go srv.RunWorkloadActivitySweep(ctx, cfg.WorkloadActivitySweepInterval, cfg.WorkloadKeepaliveGrace) listener, err := net.Listen("tcp", cfg.GRPCAddr) if err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index 42d38ff..831b540 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,19 +14,23 @@ const ( defaultZitiManagementAddress = "ziti-management:50051" defaultNotificationsAddress = "notifications:50051" defaultGRPCAddr = ":50051" + defaultWorkloadActivitySweep = 5 * time.Second + 
defaultKeepaliveGrace = 25 * time.Second ) // Config captures runtime configuration derived from the environment. type Config struct { - DatabaseURL string - IdentityAddress string - AuthorizationAddress string - AgentsAddress string - ZitiManagementAddress string - NotificationsAddress string - ZitiLeaseRenewalInterval time.Duration - ZitiEnrollmentTimeout time.Duration - GRPCAddr string + DatabaseURL string + IdentityAddress string + AuthorizationAddress string + AgentsAddress string + ZitiManagementAddress string + NotificationsAddress string + ZitiLeaseRenewalInterval time.Duration + ZitiEnrollmentTimeout time.Duration + WorkloadActivitySweepInterval time.Duration + WorkloadKeepaliveGrace time.Duration + GRPCAddr string } // Load reads configuration from environment variables, applying defaults when @@ -71,6 +75,34 @@ func Load() (Config, error) { if cfg.ZitiEnrollmentTimeout <= 0 { return Config{}, fmt.Errorf("ZITI_ENROLLMENT_TIMEOUT must be greater than 0") } + + activitySweepInterval := strings.TrimSpace(os.Getenv("WORKLOAD_ACTIVITY_SWEEP_INTERVAL")) + if activitySweepInterval == "" { + cfg.WorkloadActivitySweepInterval = defaultWorkloadActivitySweep + } else { + parsed, err := time.ParseDuration(activitySweepInterval) + if err != nil { + return Config{}, fmt.Errorf("parse WORKLOAD_ACTIVITY_SWEEP_INTERVAL: %w", err) + } + cfg.WorkloadActivitySweepInterval = parsed + } + if cfg.WorkloadActivitySweepInterval <= 0 { + return Config{}, fmt.Errorf("WORKLOAD_ACTIVITY_SWEEP_INTERVAL must be greater than 0") + } + + keepaliveGrace := strings.TrimSpace(os.Getenv("KEEPALIVE_GRACE")) + if keepaliveGrace == "" { + cfg.WorkloadKeepaliveGrace = defaultKeepaliveGrace + } else { + parsed, err := time.ParseDuration(keepaliveGrace) + if err != nil { + return Config{}, fmt.Errorf("parse KEEPALIVE_GRACE: %w", err) + } + cfg.WorkloadKeepaliveGrace = parsed + } + if cfg.WorkloadKeepaliveGrace <= 0 { + return Config{}, fmt.Errorf("KEEPALIVE_GRACE must be greater than 0") + } 
cfg.GRPCAddr = readEnv("GRPC_ADDR", defaultGRPCAddr) return cfg, nil diff --git a/internal/server/workload_activity.go b/internal/server/workload_activity.go new file mode 100644 index 0000000..efb5cc3 --- /dev/null +++ b/internal/server/workload_activity.go @@ -0,0 +1,58 @@ +package server + +import ( + "context" + "fmt" + "log" + "time" +) + +func (s *Server) RunWorkloadActivitySweep(ctx context.Context, interval, keepaliveGrace time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.sweepWorkloadActivity(ctx, time.Now().UTC(), keepaliveGrace); err != nil { + log.Printf("runners: workload activity sweep: %v", err) + } + } + } +} + +func (s *Server) sweepWorkloadActivity(ctx context.Context, now time.Time, keepaliveGrace time.Duration) error { + cutoff := now.Add(-keepaliveGrace) + workloads, err := s.setIdleWorkloads(ctx, cutoff) + if err != nil { + return err + } + for _, workload := range workloads { + s.publishWorkloadUpdateNotifications(ctx, workload, false, false, false, true) + } + return nil +} + +func (s *Server) setIdleWorkloads(ctx context.Context, cutoff time.Time) ([]workloadRecord, error) { + query := fmt.Sprintf("UPDATE workloads SET agent_state = $1, updated_at = NOW() WHERE status = $2 AND agent_state = $3 AND last_activity_at < $4 AND removed_at IS NULL RETURNING %s", workloadColumns) + rows, err := s.pool.Query(ctx, query, workloadAgentStateIdle, workloadStatusRunning, workloadAgentStateProcessing, cutoff) + if err != nil { + return nil, err + } + defer rows.Close() + + workloads := []workloadRecord{} + for rows.Next() { + workload, err := scanWorkload(rows) + if err != nil { + return nil, err + } + workloads = append(workloads, workload) + } + if err := rows.Err(); err != nil { + return nil, err + } + return workloads, nil +} diff --git a/internal/server/workload_logs_test.go b/internal/server/workload_logs_test.go index ade7b2a..1036bdc 
100644 --- a/internal/server/workload_logs_test.go +++ b/internal/server/workload_logs_test.go @@ -131,7 +131,7 @@ func TestStreamWorkloadLogsProxiesRunnerStream(t *testing.T) { } workloadRows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), instanceID, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), instanceID, now, nil, nil, now, now) workloadQuery := fmt.Sprintf("SELECT %s FROM workloads WHERE id = $1", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(workloadQuery)). WithArgs(workloadID). @@ -238,7 +238,7 @@ func TestStreamWorkloadLogsRequiresViewWorkloads(t *testing.T) { containersJSON := []byte("[]") workloadRows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), "instance-1", now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), "instance-1", now, nil, nil, now, now) workloadQuery := fmt.Sprintf("SELECT %s FROM workloads WHERE id = $1", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(workloadQuery)). WithArgs(workloadID). @@ -295,7 +295,7 @@ func TestStreamWorkloadLogsMissingInstanceID(t *testing.T) { containersJSON := []byte("[]") workloadRows := pgxmock.NewRows(workloadRowColumns). 
- AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) workloadQuery := fmt.Sprintf("SELECT %s FROM workloads WHERE id = $1", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(workloadQuery)). WithArgs(workloadID). diff --git a/internal/server/workloads.go b/internal/server/workloads.go index 8e109d5..51b9518 100644 --- a/internal/server/workloads.go +++ b/internal/server/workloads.go @@ -30,6 +30,9 @@ const ( workloadStatusStopped = "stopped" workloadStatusFailed = "failed" + workloadAgentStateProcessing = "processing" + workloadAgentStateIdle = "idle" + workloadFailureReasonStartFailed = "start_failed" workloadFailureReasonImagePullFailed = "image_pull_failed" workloadFailureReasonConfigInvalid = "config_invalid" @@ -44,7 +47,7 @@ const ( containerStatusTerminated = "terminated" containerStatusWaiting = "waiting" - workloadColumns = `id, runner_id, thread_id, agent_id, organization_id, status, failure_reason, failure_message, containers, ziti_identity_id, allocated_cpu_millicores, allocated_ram_bytes, instance_id, last_activity_at, last_metering_sampled_at, removed_at, created_at, updated_at` + workloadColumns = `id, runner_id, thread_id, agent_id, organization_id, status, agent_state, failure_reason, failure_message, containers, ziti_identity_id, allocated_cpu_millicores, allocated_ram_bytes, instance_id, last_activity_at, last_metering_sampled_at, removed_at, created_at, updated_at` ) type workloadRecord struct { @@ -54,6 +57,7 @@ type workloadRecord struct { AgentID uuid.UUID OrganizationID uuid.UUID Status string + AgentState string FailureReason *string FailureMessage *string Containers []containerRecord @@ -80,14 +84,15 @@ type 
workloadInsertInput struct { } type workloadUpdateInput struct { - ID uuid.UUID - Status *string - FailureReason *string - FailureMessage *string - ContainersJSON *[]byte - InstanceID *string - RemovedAt *time.Time - LastMeteringAt *time.Time + ID uuid.UUID + Status *string + FailureReason *string + FailureMessage *string + ContainersJSON *[]byte + InstanceID *string + RemovedAt *time.Time + LastMeteringAt *time.Time + ResetLastActivity bool } type workloadListFilter struct { @@ -263,8 +268,16 @@ func (s *Server) UpdateWorkload(ctx context.Context, req *runnersv1.UpdateWorklo return nil, status.Error(codes.InvalidArgument, "at least one field must be provided") } - var existingWorkload *workloadRecord + needsExisting := false if s.notificationsClient != nil && (statusValue != nil || containersJSON != nil || failureReason != nil || failureMessage != nil) { + needsExisting = true + } + if statusValue != nil && *statusValue == workloadStatusRunning { + needsExisting = true + } + + var existingWorkload *workloadRecord + if needsExisting { workload, err := s.getWorkloadByID(ctx, id) if err != nil { return nil, toStatusError(err) @@ -272,15 +285,18 @@ func (s *Server) UpdateWorkload(ctx context.Context, req *runnersv1.UpdateWorklo existingWorkload = &workload } + resetLastActivity := existingWorkload != nil && statusValue != nil && *statusValue == workloadStatusRunning && existingWorkload.Status == workloadStatusStarting + workload, err := s.updateWorkload(ctx, workloadUpdateInput{ - ID: id, - Status: statusValue, - FailureReason: failureReason, - FailureMessage: failureMessage, - ContainersJSON: containersJSON, - InstanceID: instanceID, - RemovedAt: removedAt, - LastMeteringAt: lastMeteringAt, + ID: id, + Status: statusValue, + FailureReason: failureReason, + FailureMessage: failureMessage, + ContainersJSON: containersJSON, + InstanceID: instanceID, + RemovedAt: removedAt, + LastMeteringAt: lastMeteringAt, + ResetLastActivity: resetLastActivity, }) if err != nil { 
return nil, toStatusError(err) @@ -289,8 +305,9 @@ func (s *Server) UpdateWorkload(ctx context.Context, req *runnersv1.UpdateWorklo containersChanged := existingWorkload != nil && containersJSON != nil && !containersEqualByName(existingWorkload.Containers, containerRecords) failureReasonChanged := existingWorkload != nil && failureReason != nil && (existingWorkload.FailureReason == nil || *existingWorkload.FailureReason != *failureReason) failureMessageChanged := existingWorkload != nil && failureMessage != nil && (existingWorkload.FailureMessage == nil || *existingWorkload.FailureMessage != *failureMessage) - if statusChanged || containersChanged || failureReasonChanged || failureMessageChanged { - s.publishWorkloadUpdateNotifications(ctx, workload, statusChanged, containersChanged, failureReasonChanged || failureMessageChanged) + agentStateChanged := existingWorkload != nil && workload.AgentState != existingWorkload.AgentState + if statusChanged || containersChanged || failureReasonChanged || failureMessageChanged || agentStateChanged { + s.publishWorkloadUpdateNotifications(ctx, workload, statusChanged, containersChanged, failureReasonChanged || failureMessageChanged, agentStateChanged) } protoWorkload, err := toProtoWorkload(workload) if err != nil { @@ -299,13 +316,14 @@ func (s *Server) UpdateWorkload(ctx context.Context, req *runnersv1.UpdateWorklo return &runnersv1.UpdateWorkloadResponse{Workload: protoWorkload}, nil } -func (s *Server) publishWorkloadUpdateNotifications(ctx context.Context, workload workloadRecord, statusChanged, containersChanged, failureChanged bool) { +func (s *Server) publishWorkloadUpdateNotifications(ctx context.Context, workload workloadRecord, statusChanged, containersChanged, failureChanged, agentStateChanged bool) { if s.notificationsClient == nil { return } payloadFields := map[string]any{ "workload_id": workload.Meta.ID.String(), "status": workload.Status, + "agent_state": workload.AgentState, } if workload.FailureReason != nil 
{ payloadFields["failure_reason"] = *workload.FailureReason @@ -323,7 +341,7 @@ func (s *Server) publishWorkloadUpdateNotifications(ctx context.Context, workloa fmt.Sprintf("organization:%s", workload.OrganizationID.String()), workloadRoom, } - if statusChanged || containersChanged || failureChanged { + if statusChanged || containersChanged || failureChanged || agentStateChanged { s.publishWorkloadNotification(ctx, "workload.updated", updatedRooms, payload) } if statusChanged { @@ -388,9 +406,13 @@ func (s *Server) TouchWorkload(ctx context.Context, req *runnersv1.TouchWorkload if err != nil { return nil, status.Errorf(codes.InvalidArgument, "id: %v", err) } - if err := s.touchWorkloadForAgent(ctx, id, callerID); err != nil { + workload, err := s.touchWorkloadForAgent(ctx, id, callerID) + if err != nil { return nil, toStatusError(err) } + if workload != nil { + s.publishWorkloadUpdateNotifications(ctx, *workload, false, false, false, true) + } return &runnersv1.TouchWorkloadResponse{}, nil } @@ -666,6 +688,9 @@ func (s *Server) updateWorkload(ctx context.Context, input workloadUpdateInput) if input.LastMeteringAt != nil { addUpdateClause(&clauses, &args, "last_metering_sampled_at", *input.LastMeteringAt) } + if input.ResetLastActivity { + clauses = append(clauses, "last_activity_at = NOW()") + } query, args := buildUpdateQuery("workloads", workloadColumns, clauses, args, input.ID) row := s.pool.QueryRow(ctx, query, args...) 
workload, err := scanWorkload(row) @@ -711,19 +736,28 @@ func (s *Server) softDeleteWorkload(ctx context.Context, id uuid.UUID) error { return nil } -func (s *Server) touchWorkloadForAgent(ctx context.Context, id uuid.UUID, agentID uuid.UUID) error { +func (s *Server) touchWorkloadForAgent(ctx context.Context, id uuid.UUID, agentID uuid.UUID) (*workloadRecord, error) { + query := fmt.Sprintf("UPDATE workloads SET agent_state = $1, last_activity_at = NOW(), updated_at = NOW() WHERE id = $2 AND agent_id = $3 AND agent_state = $4 RETURNING %s", workloadColumns) + row := s.pool.QueryRow(ctx, query, workloadAgentStateProcessing, id, agentID, workloadAgentStateIdle) + workload, err := scanWorkload(row) + if err == nil { + return &workload, nil + } + if !errors.Is(err, pgx.ErrNoRows) { + return nil, err + } result, err := s.pool.Exec(ctx, `UPDATE workloads SET last_activity_at = NOW(), updated_at = NOW() WHERE id = $1 AND agent_id = $2`, id, agentID) if err != nil { - return err + return nil, err } if result.RowsAffected() == 0 { _, err := s.getWorkloadByID(ctx, id) if err != nil { - return err + return nil, err } - return PermissionDenied() + return nil, PermissionDenied() } - return nil + return nil, nil } func (s *Server) listWorkloadsByThread(ctx context.Context, threadID uuid.UUID, agentID *uuid.UUID, statuses []string, pageSize int32, pageToken string) ([]workloadRecord, string, error) { @@ -1191,6 +1225,7 @@ func scanWorkload(row pgx.Row) (workloadRecord, error) { &workload.AgentID, &workload.OrganizationID, &workload.Status, + &workload.AgentState, &failureReason, &failureMessage, &containersData, @@ -1240,6 +1275,10 @@ func toProtoWorkload(record workloadRecord) (*runnersv1.Workload, error) { if err != nil { return nil, err } + agentStateValue, err := workloadAgentStateFromString(record.AgentState) + if err != nil { + return nil, err + } containers, err := containersToProto(record.Containers) if err != nil { return nil, err @@ -1251,6 +1290,7 @@ func 
toProtoWorkload(record workloadRecord) (*runnersv1.Workload, error) { AgentId: record.AgentID.String(), OrganizationId: record.OrganizationID.String(), Status: statusValue, + AgentState: agentStateValue, Containers: containers, ZitiIdentityId: record.ZitiIdentityID, LastActivityAt: timestamppb.New(record.LastActivityAt), @@ -1493,6 +1533,17 @@ func workloadStatusToString(status runnersv1.WorkloadStatus) (string, error) { } } +func workloadAgentStateToString(state runnersv1.WorkloadAgentState) (string, error) { + switch state { + case runnersv1.WorkloadAgentState_WORKLOAD_AGENT_STATE_PROCESSING: + return workloadAgentStateProcessing, nil + case runnersv1.WorkloadAgentState_WORKLOAD_AGENT_STATE_IDLE: + return workloadAgentStateIdle, nil + default: + return "", fmt.Errorf("invalid workload agent state: %s", state.String()) + } +} + func workloadFailureReasonToString(reason runnersv1.WorkloadFailureReason) (string, error) { switch reason { case runnersv1.WorkloadFailureReason_WORKLOAD_FAILURE_REASON_START_FAILED: @@ -1527,6 +1578,17 @@ func workloadStatusFromString(value string) (runnersv1.WorkloadStatus, error) { } } +func workloadAgentStateFromString(value string) (runnersv1.WorkloadAgentState, error) { + switch value { + case workloadAgentStateProcessing: + return runnersv1.WorkloadAgentState_WORKLOAD_AGENT_STATE_PROCESSING, nil + case workloadAgentStateIdle: + return runnersv1.WorkloadAgentState_WORKLOAD_AGENT_STATE_IDLE, nil + default: + return runnersv1.WorkloadAgentState_WORKLOAD_AGENT_STATE_UNSPECIFIED, fmt.Errorf("invalid workload agent state: %s", value) + } +} + func workloadFailureReasonFromString(value string) (runnersv1.WorkloadFailureReason, error) { switch value { case workloadFailureReasonStartFailed: diff --git a/internal/server/workloads_test.go b/internal/server/workloads_test.go index 74fe875..d86f111 100644 --- a/internal/server/workloads_test.go +++ b/internal/server/workloads_test.go @@ -31,6 +31,7 @@ var workloadRowColumns = []string{ 
"agent_id", "organization_id", "status", + "agent_state", "failure_reason", "failure_message", "containers", @@ -76,7 +77,7 @@ func TestListWorkloadsFiltersOrganization(t *testing.T) { containersJSON := []byte("[]") rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) query := fmt.Sprintf("SELECT %s FROM workloads WHERE workloads.organization_id = $1 ORDER BY workloads.created_at DESC, workloads.id ASC LIMIT $2", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(query)). @@ -152,7 +153,7 @@ func TestListWorkloadsInternalNoIdentity(t *testing.T) { containersJSON := []byte("[]") rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) limit := normalizePageSize(0) query := fmt.Sprintf("SELECT %s FROM workloads ORDER BY workloads.created_at DESC, workloads.id ASC LIMIT $1", workloadColumns) @@ -221,7 +222,7 @@ func TestListWorkloadsFiltersRunner(t *testing.T) { containersJSON := []byte("[]") rows := pgxmock.NewRows(workloadRowColumns). 
- AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) query := fmt.Sprintf("SELECT %s FROM workloads WHERE workloads.organization_id = $1 AND workloads.runner_id = ANY($2) ORDER BY workloads.created_at DESC, workloads.id ASC LIMIT $3", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(query)). @@ -299,7 +300,7 @@ func TestListWorkloadsPendingSample(t *testing.T) { containersJSON := []byte("[]") rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) query := fmt.Sprintf("SELECT %s FROM workloads WHERE workloads.organization_id = $1 AND %s ORDER BY workloads.created_at DESC, workloads.id ASC LIMIT $2", workloadColumns, pendingSampleClause) mockPool.ExpectQuery(regexp.QuoteMeta(query)). @@ -368,7 +369,7 @@ func TestListWorkloadsCursorPagination(t *testing.T) { } rows := pgxmock.NewRows(workloadRowColumns). 
- AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) pageSize := int32(2) limit := normalizePageSize(pageSize) @@ -430,7 +431,7 @@ func TestListWorkloadsSortByAgentQuery(t *testing.T) { sortExpr := "CASE workloads.agent_id WHEN $2 THEN $3 END" query := fmt.Sprintf("SELECT %s FROM workloads WHERE workloads.organization_id = $1 AND (%s > $4 OR (%s = $4 AND workloads.id > $5)) ORDER BY %s ASC, workloads.id ASC LIMIT $6", workloadColumns, sortExpr, sortExpr, sortExpr) rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) mockPool.ExpectQuery(regexp.QuoteMeta(query)). WithArgs(organizationID, agentID, primary, primary, cursorID, int(limit)+1). WillReturnRows(rows) @@ -486,7 +487,7 @@ func TestListWorkloadsSortByRunnerQuery(t *testing.T) { sortColumn := "LOWER(runners.name)" query := fmt.Sprintf("SELECT %s FROM workloads JOIN runners ON workloads.runner_id = runners.id WHERE workloads.organization_id = $1 AND (%s < $2 OR (%s = $2 AND workloads.id > $3)) ORDER BY %s DESC, workloads.id ASC LIMIT $4", workloadColumns, sortColumn, sortColumn, sortColumn) rows := pgxmock.NewRows(workloadRowColumns). 
- AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) mockPool.ExpectQuery(regexp.QuoteMeta(query)). WithArgs(organizationID, primary, cursorID, int(limit)+1). WillReturnRows(rows) @@ -637,7 +638,7 @@ func TestListWorkloadsByThreadFilters(t *testing.T) { limit := normalizePageSize(pageSize) rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) query := fmt.Sprintf("SELECT %s FROM workloads WHERE thread_id = $1 AND agent_id = $2 AND status = ANY($3) ORDER BY created_at DESC, id DESC LIMIT $4", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(query)). @@ -677,7 +678,7 @@ func TestListWorkloadsByThreadInternalNoIdentity(t *testing.T) { limit := normalizePageSize(0) rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) query := fmt.Sprintf("SELECT %s FROM workloads WHERE thread_id = $1 ORDER BY created_at DESC, id DESC LIMIT $2", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(query)). 
@@ -738,9 +739,9 @@ func TestListWorkloadsByThreadPagination(t *testing.T) { thirdID := uuid.New() rows := pgxmock.NewRows(workloadRowColumns). - AddRow(firstID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, firstAt, nil, nil, firstAt, firstAt). - AddRow(secondID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, secondAt, nil, nil, secondAt, secondAt). - AddRow(thirdID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, thirdAt, nil, nil, thirdAt, thirdAt) + AddRow(firstID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, firstAt, nil, nil, firstAt, firstAt). + AddRow(secondID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, secondAt, nil, nil, secondAt, secondAt). + AddRow(thirdID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, thirdAt, nil, nil, thirdAt, thirdAt) query := fmt.Sprintf("SELECT %s FROM workloads WHERE thread_id = $1 AND (created_at < $2 OR (created_at = $2 AND id < $3)) ORDER BY created_at DESC, id DESC LIMIT $4", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(query)). @@ -784,8 +785,8 @@ func TestListWorkloadsByThreadPaginationTieBreak(t *testing.T) { secondID := uuid.MustParse("00000000-0000-0000-0000-000000000000") rows := pgxmock.NewRows(workloadRowColumns). - AddRow(firstID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, createdAt, nil, nil, createdAt, createdAt). 
- AddRow(secondID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, createdAt, nil, nil, createdAt, createdAt) + AddRow(firstID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, createdAt, nil, nil, createdAt, createdAt). + AddRow(secondID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, createdAt, nil, nil, createdAt, createdAt) query := fmt.Sprintf("SELECT %s FROM workloads WHERE thread_id = $1 ORDER BY created_at DESC, id DESC LIMIT $2", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(query)). @@ -839,7 +840,7 @@ func TestGetWorkloadRequiresViewWorkloads(t *testing.T) { containersJSON := []byte("[]") rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) query := fmt.Sprintf("SELECT %s FROM workloads WHERE id = $1", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(workloadID).WillReturnRows(rows) @@ -880,6 +881,11 @@ func TestTouchWorkload(t *testing.T) { agentID := uuid.New() callerID := agentID + updateQuery := fmt.Sprintf("UPDATE workloads SET agent_state = $1, last_activity_at = NOW(), updated_at = NOW() WHERE id = $2 AND agent_id = $3 AND agent_state = $4 RETURNING %s", workloadColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(updateQuery)). + WithArgs(workloadAgentStateProcessing, workloadID, callerID, workloadAgentStateIdle). 
+ WillReturnRows(pgxmock.NewRows(workloadRowColumns)) + query := "UPDATE workloads SET last_activity_at = NOW(), updated_at = NOW() WHERE id = $1 AND agent_id = $2" mockPool.ExpectExec(regexp.QuoteMeta(query)). WithArgs(workloadID, callerID). @@ -897,6 +903,70 @@ func TestTouchWorkload(t *testing.T) { } } +func TestTouchWorkloadPublishesUpdateWhenIdle(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + workloadID := uuid.New() + runnerID := uuid.New() + threadID := uuid.New() + agentID := uuid.New() + organizationID := uuid.New() + now := time.Now().UTC() + containersJSON := []byte("[]") + + updateQuery := fmt.Sprintf("UPDATE workloads SET agent_state = $1, last_activity_at = NOW(), updated_at = NOW() WHERE id = $2 AND agent_id = $3 AND agent_state = $4 RETURNING %s", workloadColumns) + rows := pgxmock.NewRows(workloadRowColumns). + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + mockPool.ExpectQuery(regexp.QuoteMeta(updateQuery)). + WithArgs(workloadAgentStateProcessing, workloadID, agentID, workloadAgentStateIdle). 
+ WillReturnRows(rows) + + published := make([]*notificationsv1.PublishRequest, 0, 1) + notificationsClient := fakeNotificationsClient{publish: func(ctx context.Context, req *notificationsv1.PublishRequest) (*notificationsv1.PublishResponse, error) { + published = append(published, req) + return &notificationsv1.PublishResponse{}, nil + }} + + srv := New(Options{Pool: mockPool, NotificationsClient: notificationsClient}) + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(identityMetadata, agentID.String())) + _, err = srv.TouchWorkload(ctx, &runnersv1.TouchWorkloadRequest{Id: workloadID.String()}) + if err != nil { + t.Fatalf("TouchWorkload failed: %v", err) + } + if len(published) != 1 { + t.Fatalf("expected 1 notification, got %d", len(published)) + } + request := published[0] + if request.GetEvent() != "workload.updated" { + t.Fatalf("unexpected event: %s", request.GetEvent()) + } + workloadRoom := fmt.Sprintf("workload:%s", workloadID) + orgRoom := fmt.Sprintf("organization:%s", organizationID) + rooms := request.GetRooms() + if len(rooms) != 2 || !hasRoom(rooms, workloadRoom) || !hasRoom(rooms, orgRoom) { + t.Fatalf("unexpected workload.updated rooms: %v", rooms) + } + payload := request.GetPayload().AsMap() + if payload["workload_id"] != workloadID.String() { + t.Fatalf("unexpected workload_id payload: %v", payload["workload_id"]) + } + statusValue, ok := payload["status"].(string) + if !ok || statusValue != workloadStatusRunning { + t.Fatalf("unexpected status payload: %v", payload["status"]) + } + agentStateValue, ok := payload["agent_state"].(string) + if !ok || agentStateValue != workloadAgentStateProcessing { + t.Fatalf("unexpected agent_state payload: %v", payload["agent_state"]) + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestTouchWorkloadRequiresAgentIdentity(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil { @@ -912,6 +982,11 @@ func 
TestTouchWorkloadRequiresAgentIdentity(t *testing.T) { now := time.Now().UTC() containersJSON := []byte("[]") + updateQuery := fmt.Sprintf("UPDATE workloads SET agent_state = $1, last_activity_at = NOW(), updated_at = NOW() WHERE id = $2 AND agent_id = $3 AND agent_state = $4 RETURNING %s", workloadColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(updateQuery)). + WithArgs(workloadAgentStateProcessing, workloadID, callerID, workloadAgentStateIdle). + WillReturnRows(pgxmock.NewRows(workloadRowColumns)) + query := "UPDATE workloads SET last_activity_at = NOW(), updated_at = NOW() WHERE id = $1 AND agent_id = $2" mockPool.ExpectExec(regexp.QuoteMeta(query)). WithArgs(workloadID, callerID). @@ -919,7 +994,7 @@ func TestTouchWorkloadRequiresAgentIdentity(t *testing.T) { getQuery := fmt.Sprintf(`SELECT %s FROM workloads WHERE id = $1`, workloadColumns) rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) mockPool.ExpectQuery(regexp.QuoteMeta(getQuery)).WithArgs(workloadID).WillReturnRows(rows) srv := New(Options{Pool: mockPool}) @@ -934,6 +1009,76 @@ func TestTouchWorkloadRequiresAgentIdentity(t *testing.T) { } } +func TestSweepWorkloadActivityPublishesUpdates(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + firstID := uuid.New() + secondID := uuid.New() + runnerID := uuid.New() + threadID := uuid.New() + agentID := uuid.New() + organizationID := uuid.New() + now := time.Now().UTC() + keepaliveGrace := 25 * time.Second + cutoff := now.Add(-keepaliveGrace) + lastActivity := cutoff.Add(-2 * time.Second) + containersJSON := 
[]byte("[]") + + updateQuery := fmt.Sprintf("UPDATE workloads SET agent_state = $1, updated_at = NOW() WHERE status = $2 AND agent_state = $3 AND last_activity_at < $4 AND removed_at IS NULL RETURNING %s", workloadColumns) + rows := pgxmock.NewRows(workloadRowColumns). + AddRow(firstID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateIdle, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, lastActivity, nil, nil, now, now). + AddRow(secondID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateIdle, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, lastActivity, nil, nil, now, now) + mockPool.ExpectQuery(regexp.QuoteMeta(updateQuery)). + WithArgs(workloadAgentStateIdle, workloadStatusRunning, workloadAgentStateProcessing, cutoff). + WillReturnRows(rows) + + published := []*notificationsv1.PublishRequest{} + notificationsClient := fakeNotificationsClient{publish: func(ctx context.Context, req *notificationsv1.PublishRequest) (*notificationsv1.PublishResponse, error) { + published = append(published, req) + return &notificationsv1.PublishResponse{}, nil + }} + + srv := New(Options{Pool: mockPool, NotificationsClient: notificationsClient}) + if err := srv.sweepWorkloadActivity(context.Background(), now, keepaliveGrace); err != nil { + t.Fatalf("sweepWorkloadActivity failed: %v", err) + } + if len(published) != 2 { + t.Fatalf("expected 2 notifications, got %d", len(published)) + } + + for _, req := range published { + if req.GetEvent() != "workload.updated" { + t.Fatalf("unexpected event: %s", req.GetEvent()) + } + payload := req.GetPayload().AsMap() + agentStateValue, ok := payload["agent_state"].(string) + if !ok || agentStateValue != workloadAgentStateIdle { + t.Fatalf("unexpected agent_state payload: %v", payload["agent_state"]) + } + statusValue, ok := payload["status"].(string) + if !ok || statusValue != workloadStatusRunning { + t.Fatalf("unexpected status payload: %v",
payload["status"]) + } + workloadID, ok := payload["workload_id"].(string) + if !ok { + t.Fatalf("expected workload_id payload, got %v", payload["workload_id"]) + } + rooms := req.GetRooms() + workloadRoom := fmt.Sprintf("workload:%s", workloadID) + orgRoom := fmt.Sprintf("organization:%s", organizationID) + if len(rooms) != 2 || !hasRoom(rooms, workloadRoom) || !hasRoom(rooms, orgRoom) { + t.Fatalf("unexpected workload.updated rooms: %v", rooms) + } + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestUpdateWorkload(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil { @@ -974,8 +1119,15 @@ func TestUpdateWorkload(t *testing.T) { t.Fatalf("failed to marshal containers: %v", err) } + selectRows := pgxmock.NewRows(workloadRowColumns). + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), instanceID, now, nil, nil, now, now) + selectQuery := fmt.Sprintf("SELECT %s FROM workloads WHERE id = $1", workloadColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(selectQuery)). + WithArgs(workloadID). + WillReturnRows(selectRows) + rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), instanceID, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), instanceID, now, nil, nil, now, now) query := fmt.Sprintf("UPDATE workloads SET status = $1, containers = $2, instance_id = $3, updated_at = NOW() WHERE id = $4 RETURNING %s", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(query)). 
@@ -1016,7 +1168,7 @@ func TestUpdateWorkloadPublishesNotifications(t *testing.T) { containersJSON := []byte("[]") selectRows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusStarting, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusStarting, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) selectQuery := fmt.Sprintf("SELECT %s FROM workloads WHERE id = $1", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(selectQuery)). @@ -1051,9 +1203,9 @@ func TestUpdateWorkloadPublishesNotifications(t *testing.T) { } updateRows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, updatedContainersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, updatedContainersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) - updateQuery := fmt.Sprintf("UPDATE workloads SET status = $1, containers = $2, updated_at = NOW() WHERE id = $3 RETURNING %s", workloadColumns) + updateQuery := fmt.Sprintf("UPDATE workloads SET status = $1, containers = $2, last_activity_at = NOW(), updated_at = NOW() WHERE id = $3 RETURNING %s", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(updateQuery)). WithArgs(workloadStatusRunning, updatedContainersJSON, workloadID). WillReturnRows(updateRows) @@ -1136,7 +1288,7 @@ func TestUpdateWorkloadFailureMetadata(t *testing.T) { failureMessage := "back-off" selectRows := pgxmock.NewRows(workloadRowColumns). 
- AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) selectQuery := fmt.Sprintf("SELECT %s FROM workloads WHERE id = $1", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(selectQuery)). @@ -1144,7 +1296,7 @@ func TestUpdateWorkloadFailureMetadata(t *testing.T) { WillReturnRows(selectRows) updateRows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, failureReason, failureMessage, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, failureReason, failureMessage, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) updateQuery := fmt.Sprintf("UPDATE workloads SET failure_reason = $1, failure_message = $2, updated_at = NOW() WHERE id = $3 RETURNING %s", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(updateQuery)). @@ -1255,7 +1407,7 @@ func TestUpdateWorkloadSkipsNotificationsWhenContainersUnchanged(t *testing.T) { } selectRows := pgxmock.NewRows(workloadRowColumns). 
- AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, existingContainersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, existingContainersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) selectQuery := fmt.Sprintf("SELECT %s FROM workloads WHERE id = $1", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(selectQuery)). @@ -1273,7 +1425,7 @@ func TestUpdateWorkloadSkipsNotificationsWhenContainersUnchanged(t *testing.T) { } updateRows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, requestContainersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateProcessing, nil, nil, requestContainersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) updateQuery := fmt.Sprintf("UPDATE workloads SET containers = $1, updated_at = NOW() WHERE id = $2 RETURNING %s", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(updateQuery)). @@ -1370,7 +1522,7 @@ func TestSoftDeleteWorkload(t *testing.T) { containersJSON := []byte("[]") rows := pgxmock.NewRows(workloadRowColumns). - AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusStopped, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, now, now, now) + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusStopped, workloadAgentStateProcessing, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, now, now, now) query := fmt.Sprintf("UPDATE workloads SET status = $1, removed_at = NOW(), updated_at = NOW() WHERE id = $2 RETURNING %s", workloadColumns) mockPool.ExpectQuery(regexp.QuoteMeta(query)). 
diff --git a/migrations/0009_add_workload_agent_state.sql b/migrations/0009_add_workload_agent_state.sql new file mode 100644 index 0000000..31a06d2 --- /dev/null +++ b/migrations/0009_add_workload_agent_state.sql @@ -0,0 +1,3 @@ +ALTER TABLE workloads +ADD COLUMN IF NOT EXISTS agent_state TEXT NOT NULL DEFAULT 'processing' + CHECK (agent_state IN ('processing', 'idle')); From f249b005f5b99d7ce8935d8b0dc4b7d06232204c Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Thu, 30 Apr 2026 11:34:49 +0000 Subject: [PATCH 2/2] fix(workloads): reset activity on status --- internal/server/workloads.go | 17 ++++- internal/server/workloads_test.go | 101 ++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 3 deletions(-) diff --git a/internal/server/workloads.go b/internal/server/workloads.go index 51b9518..82d5bbc 100644 --- a/internal/server/workloads.go +++ b/internal/server/workloads.go @@ -382,10 +382,21 @@ func (s *Server) UpdateWorkloadStatus(ctx context.Context, req *runnersv1.Update return nil, status.Errorf(codes.Internal, "marshal containers: %v", err) } + var existingWorkload *workloadRecord + if statusValue == workloadStatusRunning { + workload, err := s.getWorkloadByID(ctx, id) + if err != nil { + return nil, toStatusError(err) + } + existingWorkload = &workload + } + resetLastActivity := existingWorkload != nil && existingWorkload.Status == workloadStatusStarting + workload, err := s.updateWorkload(ctx, workloadUpdateInput{ - ID: id, - Status: &statusValue, - ContainersJSON: &containersJSON, + ID: id, + Status: &statusValue, + ContainersJSON: &containersJSON, + ResetLastActivity: resetLastActivity, }) if err != nil { return nil, toStatusError(err) diff --git a/internal/server/workloads_test.go b/internal/server/workloads_test.go index d86f111..d4780ad 100644 --- a/internal/server/workloads_test.go +++ b/internal/server/workloads_test.go @@ -704,6 +704,9 @@ func TestListWorkloadsByThreadInternalNoIdentity(t *testing.T) { if 
resp.GetWorkloads()[0].GetOrganizationId() != organizationID.String() { t.Fatalf("expected organization id %q, got %q", organizationID.String(), resp.GetWorkloads()[0].GetOrganizationId()) } + if resp.GetWorkloads()[0].GetAgentState() != runnersv1.WorkloadAgentState_WORKLOAD_AGENT_STATE_PROCESSING { + t.Fatalf("expected agent state %v, got %v", runnersv1.WorkloadAgentState_WORKLOAD_AGENT_STATE_PROCESSING, resp.GetWorkloads()[0].GetAgentState()) + } if checkCalls != 0 { t.Fatalf("expected no authorization checks, got %d", checkCalls) } @@ -871,6 +874,63 @@ func TestGetWorkloadRequiresViewWorkloads(t *testing.T) { } } +func TestGetWorkloadReturnsAgentState(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + workloadID := uuid.New() + runnerID := uuid.New() + threadID := uuid.New() + agentID := uuid.New() + organizationID := uuid.New() + callerID := uuid.New() + now := time.Now().UTC() + containersJSON := []byte("[]") + + rows := pgxmock.NewRows(workloadRowColumns). + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, workloadAgentStateIdle, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + + query := fmt.Sprintf("SELECT %s FROM workloads WHERE id = $1", workloadColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(workloadID).WillReturnRows(rows) + + runnerName := "runner-name" + runnerRows := pgxmock.NewRows([]string{"id", "name"}).AddRow(runnerID, runnerName) + mockPool.ExpectQuery(regexp.QuoteMeta("SELECT id, name FROM runners WHERE id = ANY($1)")). + WithArgs(pgtype.FlatArray[uuid.UUID]([]uuid.UUID{runnerID})). 
+ WillReturnRows(runnerRows) + + agentName := "agent-name" + agentsClient := fakeAgentsClient{getAgent: func(ctx context.Context, req *agentsv1.GetAgentRequest) (*agentsv1.GetAgentResponse, error) { + return &agentsv1.GetAgentResponse{Agent: &agentsv1.Agent{Name: agentName}}, nil + }} + + authorizationClient := fakeAuthorizationClient{check: func(ctx context.Context, req *authorizationv1.CheckRequest) (*authorizationv1.CheckResponse, error) { + return &authorizationv1.CheckResponse{Allowed: true}, nil + }} + + srv := New(Options{Pool: mockPool, AuthorizationClient: authorizationClient, AgentsClient: agentsClient}) + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(identityMetadata, callerID.String())) + resp, err := srv.GetWorkload(ctx, &runnersv1.GetWorkloadRequest{Id: workloadID.String()}) + if err != nil { + t.Fatalf("GetWorkload failed: %v", err) + } + if resp.GetWorkload().GetAgentState() != runnersv1.WorkloadAgentState_WORKLOAD_AGENT_STATE_IDLE { + t.Fatalf("expected agent state %v, got %v", runnersv1.WorkloadAgentState_WORKLOAD_AGENT_STATE_IDLE, resp.GetWorkload().GetAgentState()) + } + if resp.GetWorkload().GetAgentName() != agentName { + t.Fatalf("expected agent name %q, got %q", agentName, resp.GetWorkload().GetAgentName()) + } + if resp.GetWorkload().GetRunnerName() != runnerName { + t.Fatalf("expected runner name %q, got %q", runnerName, resp.GetWorkload().GetRunnerName()) + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestTouchWorkload(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil { @@ -903,6 +963,47 @@ func TestTouchWorkload(t *testing.T) { } } +func TestTouchWorkloadNoPublishWhenProcessing(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + workloadID := uuid.New() + agentID := uuid.New() + callerID := agentID + + updateQuery := fmt.Sprintf("UPDATE workloads SET 
agent_state = $1, last_activity_at = NOW(), updated_at = NOW() WHERE id = $2 AND agent_id = $3 AND agent_state = $4 RETURNING %s", workloadColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(updateQuery)). + WithArgs(workloadAgentStateProcessing, workloadID, callerID, workloadAgentStateIdle). + WillReturnRows(pgxmock.NewRows(workloadRowColumns)) + + query := "UPDATE workloads SET last_activity_at = NOW(), updated_at = NOW() WHERE id = $1 AND agent_id = $2" + mockPool.ExpectExec(regexp.QuoteMeta(query)). + WithArgs(workloadID, callerID). + WillReturnResult(pgxmock.NewResult("UPDATE", 1)) + + publishCalls := 0 + notificationsClient := fakeNotificationsClient{publish: func(ctx context.Context, req *notificationsv1.PublishRequest) (*notificationsv1.PublishResponse, error) { + publishCalls++ + return &notificationsv1.PublishResponse{}, nil + }} + + srv := New(Options{Pool: mockPool, NotificationsClient: notificationsClient}) + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(identityMetadata, callerID.String())) + _, err = srv.TouchWorkload(ctx, &runnersv1.TouchWorkloadRequest{Id: workloadID.String()}) + if err != nil { + t.Fatalf("TouchWorkload failed: %v", err) + } + if publishCalls != 0 { + t.Fatalf("expected no notifications, got %d", publishCalls) + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestTouchWorkloadPublishesUpdateWhenIdle(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil {