From dc1b6d54f84df2c97bdd425973fd4618bc149656 Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Mon, 27 Apr 2026 21:29:52 +0000 Subject: [PATCH] fix(go-core): add runner identity context --- suites/go-core/tests/dedup_test.go | 10 ++++++---- suites/go-core/tests/expose_test.go | 7 ++++--- suites/go-core/tests/idle_test.go | 9 +++++---- suites/go-core/tests/imagepull_test.go | 5 +++-- suites/go-core/tests/mcp_test.go | 7 ++++--- suites/go-core/tests/multi_test.go | 20 ++++++++++--------- suites/go-core/tests/pipeline_test.go | 5 +++-- suites/go-core/tests/start_test.go | 7 ++++--- suites/go-core/tests/threads_send_test.go | 5 +++-- .../tests/workload_start_retry_policy_test.go | 13 ++++++------ 10 files changed, 50 insertions(+), 38 deletions(-) diff --git a/suites/go-core/tests/dedup_test.go b/suites/go-core/tests/dedup_test.go index cf05ad8..3c16719 100644 --- a/suites/go-core/tests/dedup_test.go +++ b/suites/go-core/tests/dedup_test.go @@ -37,6 +37,7 @@ func TestNoDuplicateWorkloads(t *testing.T) { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -81,17 +82,18 @@ func TestNoDuplicateWorkloads(t *testing.T) { agentCleanupCtx := withIdentity(cleanupCtx, agentID) ackAllUnackedMessagesBestEffort(t, agentCleanupCtx, threadsClient, agentID) - ids, err := findWorkloadsByLabels(cleanupCtx, runnerClient, labels) + runnerCleanupCtx := withIdentity(cleanupCtx, identityID) + ids, err := findWorkloadsByLabels(runnerCleanupCtx, runnerClient, labels) if err != nil { t.Logf("cleanup: find workloads: %v", err) return } for _, workloadID := range ids { - cleanupWorkload(t, cleanupCtx, runnerClient, workloadID) + cleanupWorkload(t, runnerCleanupCtx, runnerClient, workloadID) } }) - pollCtx, pollCancel := context.WithTimeout(ctx, 90*time.Second) + pollCtx, pollCancel := context.WithTimeout(runnerCtx, 90*time.Second) defer pollCancel() if err := pollUntil(pollCtx, pollInterval, func(ctx context.Context) error { ids, err := findWorkloadsByLabels(ctx, runnerClient, labels) @@ -112,7 +114,7 @@ func TestNoDuplicateWorkloads(t *testing.T) { defer ticker.Stop() for { - ids, err := findWorkloadsByLabels(ctx, runnerClient, labels) + ids, err := findWorkloadsByLabels(runnerCtx, runnerClient, labels) if err != nil { t.Fatalf("find workloads: %v", err) } diff --git a/suites/go-core/tests/expose_test.go b/suites/go-core/tests/expose_test.go index 9743321..9c6cc00 100644 --- a/suites/go-core/tests/expose_test.go +++ b/suites/go-core/tests/expose_test.go @@ -40,6 +40,7 @@ func TestAgentExposeListExec(t *testing.T) { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -75,13 +76,13 @@ func TestAgentExposeListExec(t *testing.T) { labelThreadID: threadID, } t.Cleanup(func() { - ids, err := findWorkloadsByLabels(ctx, runnerClient, labels) + ids, err := findWorkloadsByLabels(runnerCtx, runnerClient, labels) if err != nil { t.Logf("cleanup: find workloads: %v", err) return } for _, workloadID := range ids { - cleanupWorkload(t, ctx, runnerClient, workloadID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadID) } }) @@ -99,7 +100,7 @@ func TestAgentExposeListExec(t *testing.T) { t.Fatalf("expected agent response %q, got %q", expectedResponse, agentBody) } - workloadIDs, err := findWorkloadsByLabels(ctx, runnerClient, labels) + workloadIDs, err := findWorkloadsByLabels(runnerCtx, runnerClient, labels) if err != nil { t.Fatalf("find workloads: %v", err) } diff --git a/suites/go-core/tests/idle_test.go b/suites/go-core/tests/idle_test.go index e87711e..b56f4ef 100644 --- a/suites/go-core/tests/idle_test.go +++ b/suites/go-core/tests/idle_test.go @@ -46,6 +46,7 @@ func TestWorkloadStopsAfterIdleTimeout(t *testing.T) { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -108,10 +109,10 @@ func TestWorkloadStopsAfterIdleTimeout(t *testing.T) { if workloadID == "" { return } - cleanupWorkload(t, ctx, runnerClient, workloadID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadID) }) - pollCtx, pollCancel := context.WithTimeout(ctx, workloadWaitTimeout) + pollCtx, pollCancel := context.WithTimeout(runnerCtx, workloadWaitTimeout) defer pollCancel() if err := pollUntil(pollCtx, pollInterval, func(ctx context.Context) error { ids, err := findWorkloadsByLabels(ctx, runnerClient, labels) @@ -150,7 +151,7 @@ func TestWorkloadStopsAfterIdleTimeout(t *testing.T) { t.Fatalf("wait for unacked messages to drain: %v", err) } - idleCtx, idleCancel := context.WithTimeout(ctx, idleStopTimeout) + idleCtx, idleCancel := context.WithTimeout(runnerCtx, idleStopTimeout) defer idleCancel() if err := pollUntil(idleCtx, pollInterval, func(ctx context.Context) error { logRunnersWorkloadState(t, ctx, runnersClient, workloadID) @@ -163,7 +164,7 @@ func TestWorkloadStopsAfterIdleTimeout(t *testing.T) { } return nil }); err != nil { - diagCtx, cancelDiag := context.WithTimeout(context.Background(), 10*time.Second) + diagCtx, cancelDiag := context.WithTimeout(runnerCtx, 10*time.Second) defer cancelDiag() logRunnersWorkloadState(t, diagCtx, runnersClient, workloadID) t.Fatalf("wait for workload stop: %v", err) diff --git a/suites/go-core/tests/imagepull_test.go b/suites/go-core/tests/imagepull_test.go index 614ad18..d7f6112 100644 --- a/suites/go-core/tests/imagepull_test.go +++ b/suites/go-core/tests/imagepull_test.go @@ -50,6 +50,7 @@ func TestImagePullSecretAttachedToPod(t *testing.T) { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -115,10 +116,10 @@ func TestImagePullSecretAttachedToPod(t *testing.T) { if workloadID == "" { return } - cleanupWorkload(t, ctx, runnerClient, workloadID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadID) }) - pollCtx, pollCancel := context.WithTimeout(ctx, 90*time.Second) + pollCtx, pollCancel := context.WithTimeout(runnerCtx, 90*time.Second) defer pollCancel() if err := pollUntil(pollCtx, pollInterval, func(ctx context.Context) error { ids, err := findWorkloadsByLabels(ctx, runnerClient, labelsMap) diff --git a/suites/go-core/tests/mcp_test.go b/suites/go-core/tests/mcp_test.go index f7b0433..ddb44fe 100644 --- a/suites/go-core/tests/mcp_test.go +++ b/suites/go-core/tests/mcp_test.go @@ -46,6 +46,7 @@ func runMCPToolsE2E(t *testing.T, llmEndpoint, initImage string) pipelineRun { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -111,13 +112,13 @@ func runMCPToolsE2E(t *testing.T, llmEndpoint, initImage string) pipelineRun { labelThreadID: threadID, } t.Cleanup(func() { - ids, err := findWorkloadsByLabels(ctx, runnerClient, labels) + ids, err := findWorkloadsByLabels(runnerCtx, runnerClient, labels) if err != nil { t.Logf("cleanup: find workloads: %v", err) return } for _, workloadID := range ids { - cleanupWorkload(t, ctx, runnerClient, workloadID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadID) } }) @@ -129,7 +130,7 @@ func runMCPToolsE2E(t *testing.T, llmEndpoint, initImage string) pipelineRun { expected := "I've created the entity 'test_project' (type: project) with the observation 'A test project'. The /test-data directory contains one file: hello.txt." - readyCtx, readyCancel := context.WithTimeout(ctx, 4*time.Minute) + readyCtx, readyCancel := context.WithTimeout(runnerCtx, 4*time.Minute) defer readyCancel() if err := waitForMcpSidecarsReady(t, readyCtx, runnerClient, labels); err != nil { t.Fatalf("wait for mcp sidecars: %v", err) diff --git a/suites/go-core/tests/multi_test.go b/suites/go-core/tests/multi_test.go index a83740a..1fb923b 100644 --- a/suites/go-core/tests/multi_test.go +++ b/suites/go-core/tests/multi_test.go @@ -37,6 +37,7 @@ func TestMultipleAgentsSeparateThreads(t *testing.T) { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -91,16 +92,16 @@ func TestMultipleAgentsSeparateThreads(t *testing.T) { workloadBID := "" t.Cleanup(func() { if workloadAID != "" { - cleanupWorkload(t, ctx, runnerClient, workloadAID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadAID) } }) t.Cleanup(func() { if workloadBID != "" { - cleanupWorkload(t, ctx, runnerClient, workloadBID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadBID) } }) - pollCtx, pollCancel := context.WithTimeout(ctx, 90*time.Second) + pollCtx, pollCancel := context.WithTimeout(runnerCtx, 90*time.Second) defer pollCancel() if err := pollUntil(pollCtx, pollInterval, func(ctx context.Context) error { ids, err := findWorkloadsByLabels(ctx, runnerClient, labelsA) @@ -116,7 +117,7 @@ func TestMultipleAgentsSeparateThreads(t *testing.T) { t.Fatalf("wait for workload A: %v", err) } - pollCtxB, pollCancelB := context.WithTimeout(ctx, 90*time.Second) + pollCtxB, pollCancelB := context.WithTimeout(runnerCtx, 90*time.Second) defer pollCancelB() if err := pollUntil(pollCtxB, pollInterval, func(ctx context.Context) error { ids, err := findWorkloadsByLabels(ctx, runnerClient, labelsB) @@ -136,11 +137,11 @@ func TestMultipleAgentsSeparateThreads(t *testing.T) { t.Fatalf("expected distinct workloads, got %s", workloadAID) } - labelsRespA, err := getWorkloadLabels(ctx, runnerClient, workloadAID) + labelsRespA, err := getWorkloadLabels(runnerCtx, runnerClient, workloadAID) if err != nil { t.Fatalf("get labels for workload A: %v", err) } - labelsRespB, err := getWorkloadLabels(ctx, runnerClient, workloadBID) + labelsRespB, err := getWorkloadLabels(runnerCtx, runnerClient, workloadBID) if err != nil { t.Fatalf("get labels for workload B: %v", err) } @@ -172,6 +173,7 @@ func TestSameAgentMultipleThreads(t *testing.T) { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -215,11 +217,11 @@ func TestSameAgentMultipleThreads(t *testing.T) { workloadIDs := []string{} t.Cleanup(func() { for _, workloadID := range workloadIDs { - cleanupWorkload(t, ctx, runnerClient, workloadID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadID) } }) - pollCtx, pollCancel := context.WithTimeout(ctx, 90*time.Second) + pollCtx, pollCancel := context.WithTimeout(runnerCtx, 90*time.Second) defer pollCancel() if err := pollUntil(pollCtx, pollInterval, func(ctx context.Context) error { ids, err := findWorkloadsByLabels(ctx, runnerClient, labels) @@ -241,7 +243,7 @@ func TestSameAgentMultipleThreads(t *testing.T) { } foundThreads := map[string]bool{} for _, workloadID := range workloadIDs { - labelsResp, err := getWorkloadLabels(ctx, runnerClient, workloadID) + labelsResp, err := getWorkloadLabels(runnerCtx, runnerClient, workloadID) if err != nil { t.Fatalf("get labels for workload %s: %v", workloadID, err) } diff --git a/suites/go-core/tests/pipeline_test.go b/suites/go-core/tests/pipeline_test.go index f5f8084..4ef9783 100644 --- a/suites/go-core/tests/pipeline_test.go +++ b/suites/go-core/tests/pipeline_test.go @@ -47,6 +47,7 @@ func runFullPipelineMessageResponse(t *testing.T, llmEndpoint, initImage, messag identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -86,13 +87,13 @@ func runFullPipelineMessageResponse(t *testing.T, llmEndpoint, initImage, messag labelThreadID: threadID, } t.Cleanup(func() { - ids, err := findWorkloadsByLabels(ctx, runnerClient, labels) + ids, err := findWorkloadsByLabels(runnerCtx, runnerClient, labels) if err != nil { t.Logf("cleanup: find workloads: %v", err) return } for _, workloadID := range ids { - cleanupWorkload(t, ctx, runnerClient, workloadID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadID) } }) diff --git a/suites/go-core/tests/start_test.go b/suites/go-core/tests/start_test.go index ab62841..9903b74 100644 --- a/suites/go-core/tests/start_test.go +++ b/suites/go-core/tests/start_test.go @@ -37,6 +37,7 @@ func TestWorkloadStartsOnUnackedMessage(t *testing.T) { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -74,7 +75,7 @@ func TestWorkloadStartsOnUnackedMessage(t *testing.T) { labelThreadID: threadID, } - pollCtx, pollCancel := context.WithTimeout(ctx, 90*time.Second) + pollCtx, pollCancel := context.WithTimeout(runnerCtx, 90*time.Second) defer pollCancel() workloadID := "" if err := pollUntil(pollCtx, pollInterval, func(ctx context.Context) error { @@ -91,9 +92,9 @@ func TestWorkloadStartsOnUnackedMessage(t *testing.T) { t.Fatalf("wait for workload: %v", err) } - t.Cleanup(func() { cleanupWorkload(t, ctx, runnerClient, workloadID) }) + t.Cleanup(func() { cleanupWorkload(t, runnerCtx, runnerClient, workloadID) }) - labelsResp, err := getWorkloadLabels(ctx, runnerClient, workloadID) + labelsResp, err := getWorkloadLabels(runnerCtx, runnerClient, workloadID) if err != nil { t.Fatalf("get workload labels: %v", err) } diff --git a/suites/go-core/tests/threads_send_test.go b/suites/go-core/tests/threads_send_test.go index a6cc260..f9b11eb 100644 --- a/suites/go-core/tests/threads_send_test.go +++ b/suites/go-core/tests/threads_send_test.go @@ -38,6 +38,7 @@ func TestThreadsSendShell(t *testing.T) { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -77,13 +78,13 @@ func TestThreadsSendShell(t *testing.T) { labelThreadID: threadID, } t.Cleanup(func() { - ids, err := findWorkloadsByLabels(ctx, runnerClient, labels) + ids, err := findWorkloadsByLabels(runnerCtx, runnerClient, labels) if err != nil { t.Logf("cleanup: find workloads: %v", err) return } for _, workloadID := range ids { - cleanupWorkload(t, ctx, runnerClient, workloadID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadID) } }) diff --git a/suites/go-core/tests/workload_start_retry_policy_test.go b/suites/go-core/tests/workload_start_retry_policy_test.go index 6725025..9edf372 100644 --- a/suites/go-core/tests/workload_start_retry_policy_test.go +++ b/suites/go-core/tests/workload_start_retry_policy_test.go @@ -60,6 +60,7 @@ func TestWorkloadStartRetryPolicyFastRetry(t *testing.T) { identityID := resolveOrCreateUser(t, ctx, usersClient) threadsCtx := withIdentity(ctx, identityID) + runnerCtx := withIdentity(ctx, identityID) token := createAPIToken(t, ctx, usersClient, identityID) orgID := createTestOrganization(t, ctx, orgsClient, identityID) @@ -98,17 +99,17 @@ func TestWorkloadStartRetryPolicyFastRetry(t *testing.T) { labelThreadID: threadID, } t.Cleanup(func() { - ids, err := findWorkloadsByLabels(ctx, runnerClient, labels) + ids, err := findWorkloadsByLabels(runnerCtx, runnerClient, labels) if err != nil { t.Logf("cleanup: find workloads: %v", err) return } for _, workloadID := range ids { - cleanupWorkload(t, ctx, runnerClient, workloadID) + cleanupWorkload(t, runnerCtx, runnerClient, workloadID) } }) - failureCtx, failureCancel := context.WithTimeout(ctx, failedWorkloadTimeout) + failureCtx, failureCancel := context.WithTimeout(runnerCtx, failedWorkloadTimeout) defer failureCancel() failedWorkloads, err := waitForFailedWorkloads(failureCtx, runnersClient, threadID, agentID, 2) if err != nil { @@ -122,7 +123,7 @@ func TestWorkloadStartRetryPolicyFastRetry(t *testing.T) { assertFailedWorkload(t, failedLatest, threadID, agentID) assertFailedWorkload(t, failedPrevious, threadID, agentID) - allWorkloads, err := listWorkloadsByThread(ctx, runnersClient, threadID, agentID, nil) + allWorkloads, err := listWorkloadsByThread(runnerCtx, runnersClient, threadID, agentID, nil) if err != nil { t.Fatalf("list workloads by thread: %v", err) } @@ -159,7 +160,7 @@ func TestWorkloadStartRetryPolicyFastRetry(t *testing.T) { t.Fatalf("update agent init image: %v", err) } - fastRetryCtx, fastRetryCancel := context.WithTimeout(ctx, fastRetryTimeout) + fastRetryCtx, fastRetryCancel := context.WithTimeout(runnerCtx, fastRetryTimeout) defer fastRetryCancel() retryWorkload, err := waitForRetryWorkload(fastRetryCtx, runnersClient, threadID, agentID, removedAt) if err != nil { @@ -185,7 +186,7 @@ func TestWorkloadStartRetryPolicyFastRetry(t *testing.T) { if instanceID == "" { t.Fatalf("failed workload %s missing instance id", workloadID(t, failed)) } - cleanupCtx, cleanupCancel := context.WithTimeout(ctx, runnerCleanupTimeout) + cleanupCtx, cleanupCancel := context.WithTimeout(runnerCtx, runnerCleanupTimeout) if err := waitForRunnerWorkloadGone(cleanupCtx, runnerClient, instanceID); err != nil { cleanupCancel() t.Fatalf("wait for runner workload %s cleanup: %v", instanceID, err)