diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 428ac39..efd01da 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -9,6 +9,13 @@ concurrency: group: e2e-${{ github.event_name == 'pull_request' && github.event.pull_request.number || github.sha }} cancel-in-progress: true +env: + AGN_INIT_IMAGE: ghcr.io/agynio/agent-init-agn:0.4 + CODEX_INIT_IMAGE: ghcr.io/agynio/agent-init-codex:0.13 + CLAUDE_INIT_IMAGE: ghcr.io/agynio/agent-init-claude:0.1 + AGN_EXPOSE_INIT_IMAGE: ghcr.io/agynio/agent-init-agn:0.4 + AGYN_AGENT_INIT_IMAGE: ghcr.io/agynio/agent-init:v1.0.0 + jobs: e2e: runs-on: ubuntu-latest @@ -46,5 +53,11 @@ jobs: - name: Run E2E tests uses: agynio/e2e/.github/actions/run-tests@main + env: + AGN_INIT_IMAGE: ${{ env.AGN_INIT_IMAGE }} + CODEX_INIT_IMAGE: ${{ env.CODEX_INIT_IMAGE }} + CLAUDE_INIT_IMAGE: ${{ env.CLAUDE_INIT_IMAGE }} + AGN_EXPOSE_INIT_IMAGE: ${{ env.AGN_EXPOSE_INIT_IMAGE }} + AGYN_AGENT_INIT_IMAGE: ${{ env.AGYN_AGENT_INIT_IMAGE }} with: service: runners diff --git a/README.md b/README.md index 333124b..a3b9783 100644 --- a/README.md +++ b/README.md @@ -38,3 +38,11 @@ devspace run test-e2e --tag svc_runners E2E coverage is centralized in [agynio/e2e](https://github.com/agynio/e2e) under the go-core suite. See [E2E Testing](https://github.com/agynio/architecture/blob/main/architecture/operations/e2e-testing.md). + +## Helm chart defaults + +The chart ships with a DENY-based Istio AuthorizationPolicy. By default, +`authorizationPolicy.identityServiceAccounts` allows in-mesh callers that +forward `x-identity-id` (`gateway`, `expose`, `notifications`, `chat`). +Override the list if your deployment uses different service accounts (for +example, add `agents-orchestrator-e2e` for E2E runs). diff --git a/charts/runners/templates/authorizationpolicy.yaml b/charts/runners/templates/authorizationpolicy.yaml new file mode 100644 index 0000000..3840fbe --- /dev/null +++ b/charts/runners/templates/authorizationpolicy.yaml @@ -0,0 +1,61 @@ +{{- if .Values.authorizationPolicy.enabled }} +{{- $orchestratorNamespace := .Values.authorizationPolicy.orchestratorServiceAccount.namespace | default .Release.Namespace }} +{{- $orchestratorName := .Values.authorizationPolicy.orchestratorServiceAccount.name }} +{{- $orchestratorPrincipal := printf "cluster.local/ns/%s/sa/%s" $orchestratorNamespace $orchestratorName }} +{{- $identityServiceAccounts := .Values.authorizationPolicy.identityServiceAccounts | default list }} +{{- $identityPrincipals := list $orchestratorPrincipal }} +{{- range $identityServiceAccounts }} +{{- $identityNamespace := .namespace | default $.Release.Namespace }} +{{- $identityPrincipals = append $identityPrincipals (printf "cluster.local/ns/%s/sa/%s" $identityNamespace .name) }} +{{- end }} +{{- $gatewayPaths := list "/agynio.api.runners.v1.RunnersService/RegisterRunner" "/agynio.api.runners.v1.RunnersService/GetRunner" "/agynio.api.runners.v1.RunnersService/ListRunners" "/agynio.api.runners.v1.RunnersService/UpdateRunner" "/agynio.api.runners.v1.RunnersService/DeleteRunner" "/agynio.api.runners.v1.RunnersService/GetWorkload" "/agynio.api.runners.v1.RunnersService/ListWorkloadsByThread" "/agynio.api.runners.v1.RunnersService/ListWorkloads" "/agynio.api.runners.v1.RunnersService/TouchWorkload" "/agynio.api.runners.v1.RunnersService/StreamWorkloadLogs" "/agynio.api.runners.v1.RunnersService/GetVolume" "/agynio.api.runners.v1.RunnersService/ListVolumes" "/agynio.api.runners.v1.RunnersService/ListVolumesByThread" }} +{{- $orchestratorPaths := list "/agynio.api.runners.v1.RunnersService/CreateWorkload" "/agynio.api.runners.v1.RunnersService/UpdateWorkload" "/agynio.api.runners.v1.RunnersService/UpdateWorkloadStatus" "/agynio.api.runners.v1.RunnersService/DeleteWorkload" "/agynio.api.runners.v1.RunnersService/BatchUpdateWorkloadSampledAt" "/agynio.api.runners.v1.RunnersService/CreateVolume" "/agynio.api.runners.v1.RunnersService/UpdateVolume" "/agynio.api.runners.v1.RunnersService/BatchUpdateVolumeSampledAt" }} +apiVersion: security.istio.io/v1beta1 +kind: AuthorizationPolicy +metadata: + name: {{ include "service-base.fullname" . }}-internal + labels: + {{- include "service-base.labels" . | nindent 4 }} +spec: + selector: + matchLabels: + {{- include "service-base.selectorLabels" . | nindent 6 }} + action: DENY + rules: + - from: + - source: + notPrincipals: + - {{ $orchestratorPrincipal | quote }} + to: + - operation: + paths: + {{ range $orchestratorPaths }} + - {{ . | quote }} + {{ end }} + - from: + - source: + notPrincipals: + {{ range $identityPrincipals }} + - {{ . | quote }} + {{ end }} + to: + - operation: + paths: + {{ range $gatewayPaths }} + - {{ . | quote }} + {{ end }} + - from: + - source: + notPrincipals: + - {{ $orchestratorPrincipal | quote }} + when: + - key: request.headers[x-identity-id] + notValues: + - "*" + to: + - operation: + paths: + {{ range $gatewayPaths }} + - {{ . | quote }} + {{ end }} +{{- end }} diff --git a/charts/runners/values.yaml b/charts/runners/values.yaml index 8d51d19..341bb29 100644 --- a/charts/runners/values.yaml +++ b/charts/runners/values.yaml @@ -25,6 +25,21 @@ serviceAccount: automountServiceAccountToken: true +authorizationPolicy: + enabled: true + orchestratorServiceAccount: + name: agents-orchestrator + namespace: "" + identityServiceAccounts: + - name: gateway + namespace: "" + - name: expose + namespace: "" + - name: notifications + namespace: "" + - name: chat + namespace: "" + rbac: create: false clusterWide: false diff --git a/internal/server/authz.go b/internal/server/authz.go index f7b347f..762e24f 100644 --- a/internal/server/authz.go +++ b/internal/server/authz.go @@ -18,11 +18,30 @@ func identityFromMetadata(ctx context.Context) (uuid.UUID, error) { } values := md.Get(identityMetadata) if len(values) != 1 { - return uuid.UUID{}, fmt.Errorf("expected single value") + return uuid.UUID{}, fmt.Errorf("metadata %s: expected single value, got %d", identityMetadata, len(values)) } return parseUUID(values[0]) } +func identityFromMetadataOptional(ctx context.Context) (*uuid.UUID, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, nil + } + values := md.Get(identityMetadata) + if len(values) == 0 { + return nil, nil + } + if len(values) != 1 { + return nil, fmt.Errorf("metadata %s: expected single value, got %d", identityMetadata, len(values)) + } + parsed, err := parseUUID(values[0]) + if err != nil { + return nil, err + } + return &parsed, nil +} + func (s *Server) requireClusterAdmin(ctx context.Context, identityID uuid.UUID) error { return s.requireRelation(ctx, identityID, clusterAdminRelation, clusterObject) } diff --git a/internal/server/runners.go b/internal/server/runners.go index 7ade858..663e6d6 100644 --- a/internal/server/runners.go +++ b/internal/server/runners.go @@ -177,7 +177,7 @@ func (s *Server) RegisterRunner(ctx context.Context, req *runnersv1.RegisterRunn } func (s *Server) GetRunner(ctx context.Context, req *runnersv1.GetRunnerRequest) (*runnersv1.GetRunnerResponse, error) { - callerID, err := identityFromMetadata(ctx) + callerID, err := identityFromMetadataOptional(ctx) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "unauthenticated: %v", err) } @@ -189,8 +189,8 @@ func (s *Server) GetRunner(ctx context.Context, req *runnersv1.GetRunnerRequest) if err != nil { return nil, toStatusError(err) } - if runner.OrganizationID != nil { - if err := s.requireOrgMember(ctx, callerID, *runner.OrganizationID); err != nil { + if callerID != nil && runner.OrganizationID != nil { + if err := s.requireOrgMember(ctx, *callerID, *runner.OrganizationID); err != nil { return nil, err } } @@ -259,7 +259,7 @@ func (s *Server) UpdateRunner(ctx context.Context, req *runnersv1.UpdateRunnerRe } func (s *Server) ListRunners(ctx context.Context, req *runnersv1.ListRunnersRequest) (*runnersv1.ListRunnersResponse, error) { - callerID, err := identityFromMetadata(ctx) + callerID, err := identityFromMetadataOptional(ctx) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "unauthenticated: %v", err) } @@ -271,8 +271,8 @@ func (s *Server) ListRunners(ctx context.Context, req *runnersv1.ListRunnersRequ } organizationID = &parsed } - if organizationID != nil { - if err := s.requireOrgMember(ctx, callerID, *organizationID); err != nil { + if callerID != nil && organizationID != nil { + if err := s.requireOrgMember(ctx, *callerID, *organizationID); err != nil { return nil, err } } @@ -286,7 +286,7 @@ func (s *Server) ListRunners(ctx context.Context, req *runnersv1.ListRunnersRequ return nil, status.Errorf(codes.Internal, "list runners: %v", err) } - if organizationID == nil { + if callerID != nil && organizationID == nil { memberCache := map[uuid.UUID]bool{} filtered := make([]runnerRecord, 0, len(runners)) for _, runner := range runners { @@ -294,7 +294,7 @@ func (s *Server) ListRunners(ctx context.Context, req *runnersv1.ListRunnersRequ filtered = append(filtered, runner) continue } - allowed, err := s.memberAllowed(ctx, callerID, *runner.OrganizationID, memberCache) + allowed, err := s.memberAllowed(ctx, *callerID, *runner.OrganizationID, memberCache) if err != nil { return nil, err } diff --git a/internal/server/runners_test.go b/internal/server/runners_test.go index e0704a5..596758f 100644 --- a/internal/server/runners_test.go +++ b/internal/server/runners_test.go @@ -139,6 +139,7 @@ func (f fakeIdentityClient) ResolveNickname(ctx context.Context, req *identityv1 func (f fakeIdentityClient) BatchGetNicknames(ctx context.Context, req *identityv1.BatchGetNicknamesRequest, opts ...grpc.CallOption) (*identityv1.BatchGetNicknamesResponse, error) { return nil, status.Error(codes.Unimplemented, "not implemented") } + type fakeAgentsClient struct { getAgent func(ctx context.Context, req *agentsv1.GetAgentRequest) (*agentsv1.GetAgentResponse, error) getVolume func(ctx context.Context, req *agentsv1.GetVolumeRequest) (*agentsv1.GetVolumeResponse, error) @@ -189,6 +190,7 @@ func (f fakeAgentsClient) GetHook(ctx context.Context, req *agentsv1.GetHookRequ } return f.getHook(ctx, req) } + type fakeAuthorizationClient struct { check func(ctx context.Context, req *authorizationv1.CheckRequest) (*authorizationv1.CheckResponse, error) write func(ctx context.Context, req *authorizationv1.WriteRequest) (*authorizationv1.WriteResponse, error) @@ -491,6 +493,52 @@ func TestGetRunnerRequiresMember(t *testing.T) { } } +func TestGetRunnerInternalNoIdentity(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + runnerID := uuid.New() + organizationID := uuid.New() + identityID := uuid.New() + now := time.Now().UTC() + labelsJSON := []byte("{}") + capabilitiesJSON := []byte("[]") + rows := pgxmock.NewRows([]string{"id", "name", "organization_id", "identity_id", "ziti_identity_id", "ziti_service_id", "ziti_service_name", "status", "labels", "capabilities", "created_at", "updated_at"}). + AddRow(runnerID, "runner-1", pgtype.UUID{Bytes: organizationID, Valid: true}, identityID, "", "service-id", "runner-service", runnerStatusOffline, labelsJSON, capabilitiesJSON, now, now) + + query := fmt.Sprintf(`SELECT %s FROM runners WHERE id = $1`, runnerColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(runnerID).WillReturnRows(rows) + + checkCalls := 0 + authorizationClient := fakeAuthorizationClient{ + check: func(ctx context.Context, req *authorizationv1.CheckRequest) (*authorizationv1.CheckResponse, error) { + checkCalls++ + return &authorizationv1.CheckResponse{Allowed: false}, nil + }, + } + + srv := New(Options{Pool: mockPool, AuthorizationClient: authorizationClient}) + resp, err := srv.GetRunner(context.Background(), &runnersv1.GetRunnerRequest{Id: runnerID.String()}) + if err != nil { + t.Fatalf("GetRunner failed: %v", err) + } + if resp.GetRunner() == nil { + t.Fatal("expected runner in response") + } + if resp.GetRunner().GetOrganizationId() != organizationID.String() { + t.Fatalf("expected organization id %q, got %q", organizationID.String(), resp.GetRunner().GetOrganizationId()) + } + if checkCalls != 0 { + t.Fatalf("expected no authorization checks, got %d", checkCalls) + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestListRunnersRequiresMember(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil { @@ -519,6 +567,53 @@ func TestListRunnersRequiresMember(t *testing.T) { } } +func TestListRunnersInternalNoIdentity(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + runnerID := uuid.New() + organizationID := uuid.New() + identityID := uuid.New() + now := time.Now().UTC() + labelsJSON := []byte("{}") + capabilitiesJSON := []byte("[]") + rows := pgxmock.NewRows([]string{"id", "name", "organization_id", "identity_id", "ziti_identity_id", "ziti_service_id", "ziti_service_name", "status", "labels", "capabilities", "created_at", "updated_at"}). + AddRow(runnerID, "runner-1", pgtype.UUID{Bytes: organizationID, Valid: true}, identityID, "", "service-id", "runner-service", runnerStatusOffline, labelsJSON, capabilitiesJSON, now, now) + + limit := normalizePageSize(0) + query := fmt.Sprintf("SELECT %s FROM runners ORDER BY id ASC LIMIT $1", runnerColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(int(limit) + 1).WillReturnRows(rows) + + checkCalls := 0 + authorizationClient := fakeAuthorizationClient{ + check: func(ctx context.Context, req *authorizationv1.CheckRequest) (*authorizationv1.CheckResponse, error) { + checkCalls++ + return &authorizationv1.CheckResponse{Allowed: false}, nil + }, + } + + srv := New(Options{Pool: mockPool, AuthorizationClient: authorizationClient}) + resp, err := srv.ListRunners(context.Background(), &runnersv1.ListRunnersRequest{}) + if err != nil { + t.Fatalf("ListRunners failed: %v", err) + } + if len(resp.GetRunners()) != 1 { + t.Fatalf("expected 1 runner, got %d", len(resp.GetRunners())) + } + if resp.GetRunners()[0].GetOrganizationId() != organizationID.String() { + t.Fatalf("expected organization id %q, got %q", organizationID.String(), resp.GetRunners()[0].GetOrganizationId()) + } + if checkCalls != 0 { + t.Fatalf("expected no authorization checks, got %d", checkCalls) + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestUpdateRunnerUpdatesLabels(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil { diff --git a/internal/server/volumes.go b/internal/server/volumes.go index d03e403..8f17ba9 100644 --- a/internal/server/volumes.go +++ b/internal/server/volumes.go @@ -67,7 +67,7 @@ type volumeUpdateInput struct { } type volumeListFilter struct { - OrganizationID uuid.UUID + OrganizationID *uuid.UUID RunnerIDs []uuid.UUID Statuses []string AttachedKinds []runnersv1.VolumeAttachmentFilterKind @@ -312,7 +312,7 @@ func (s *Server) GetVolume(ctx context.Context, req *runnersv1.GetVolumeRequest) } func (s *Server) ListVolumesByThread(ctx context.Context, req *runnersv1.ListVolumesByThreadRequest) (*runnersv1.ListVolumesByThreadResponse, error) { - callerID, err := identityFromMetadata(ctx) + callerID, err := identityFromMetadataOptional(ctx) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "unauthenticated: %v", err) } @@ -328,18 +328,21 @@ func (s *Server) ListVolumesByThread(ctx context.Context, req *runnersv1.ListVol } return nil, status.Errorf(codes.Internal, "list volumes: %v", err) } - memberCache := map[uuid.UUID]bool{} - filtered := make([]volumeRecord, 0, len(volumes)) - for _, volume := range volumes { - allowed, err := s.memberAllowed(ctx, callerID, volume.OrganizationID, memberCache) - if err != nil { - return nil, err - } - if allowed { - filtered = append(filtered, volume) + if callerID != nil { + memberCache := map[uuid.UUID]bool{} + filtered := make([]volumeRecord, 0, len(volumes)) + for _, volume := range volumes { + allowed, err := s.memberAllowed(ctx, *callerID, volume.OrganizationID, memberCache) + if err != nil { + return nil, err + } + if allowed { + filtered = append(filtered, volume) + } } + volumes = filtered } - protoVolumes, err := toProtoVolumeList(filtered) + protoVolumes, err := toProtoVolumeList(volumes) if err != nil { return nil, status.Errorf(codes.Internal, "convert volumes: %v", err) } @@ -347,20 +350,30 @@ func (s *Server) ListVolumesByThread(ctx context.Context, req *runnersv1.ListVol } func (s *Server) ListVolumes(ctx context.Context, req *runnersv1.ListVolumesRequest) (*runnersv1.ListVolumesResponse, error) { - callerID, err := identityFromMetadata(ctx) + callerID, err := identityFromMetadataOptional(ctx) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "unauthenticated: %v", err) } orgValue := strings.TrimSpace(req.GetOrganizationId()) - if orgValue == "" { - return nil, status.Error(codes.InvalidArgument, "organization_id: value is empty") - } - organizationID, err := parseUUID(orgValue) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "organization_id: %v", err) - } - if err := s.requireRelation(ctx, callerID, organizationViewVolumes, organizationObject(organizationID)); err != nil { - return nil, err + var organizationID *uuid.UUID + if callerID != nil { + if orgValue == "" { + return nil, status.Error(codes.InvalidArgument, "organization_id: value is empty") + } + parsed, err := parseUUID(orgValue) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "organization_id: %v", err) + } + organizationID = &parsed + if err := s.requireRelation(ctx, *callerID, organizationViewVolumes, organizationObject(*organizationID)); err != nil { + return nil, err + } + } else if orgValue != "" { + parsed, err := parseUUID(orgValue) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "organization_id: %v", err) + } + organizationID = &parsed } filter := volumeListFilter{OrganizationID: organizationID} @@ -882,8 +895,12 @@ func volumePrimaryValue(item volumeListItem, field volumeSortField) (string, err } func (s *Server) listVolumeIDs(ctx context.Context, filter volumeListFilter) ([]uuid.UUID, error) { - clauses := []string{fmt.Sprintf("volumes.organization_id = $%d", 1)} - args := []any{filter.OrganizationID} + clauses := []string{} + args := []any{} + if filter.OrganizationID != nil { + clauses = append(clauses, fmt.Sprintf("volumes.organization_id = $%d", len(args)+1)) + args = append(args, *filter.OrganizationID) + } if len(filter.RunnerIDs) > 0 { clauses = append(clauses, fmt.Sprintf("volumes.runner_id = ANY($%d)", len(args)+1)) @@ -936,8 +953,12 @@ func addVolumeCursorClause(clauses *[]string, args *[]any, column string, direct func (s *Server) listVolumesPage(ctx context.Context, filter volumeListFilter, sort volumeListSort, pageSize int32, pageToken string) ([]volumeRecord, string, error) { limit := normalizePageSize(pageSize) - clauses := []string{fmt.Sprintf("volumes.organization_id = $%d", 1)} - args := []any{filter.OrganizationID} + clauses := []string{} + args := []any{} + if filter.OrganizationID != nil { + clauses = append(clauses, fmt.Sprintf("volumes.organization_id = $%d", len(args)+1)) + args = append(args, *filter.OrganizationID) + } if len(filter.RunnerIDs) > 0 { clauses = append(clauses, fmt.Sprintf("volumes.runner_id = ANY($%d)", len(args)+1)) @@ -1020,8 +1041,12 @@ func (s *Server) listVolumesPage(ctx context.Context, filter volumeListFilter, s func (s *Server) listVolumesByNamePage(ctx context.Context, filter volumeListFilter, sort volumeListSort, pageSize int32, pageToken string, volumeNames map[uuid.UUID]string) ([]volumeRecord, string, error) { limit := normalizePageSize(pageSize) - clauses := []string{fmt.Sprintf("volumes.organization_id = $%d", 1)} - args := []any{filter.OrganizationID} + clauses := []string{} + args := []any{} + if filter.OrganizationID != nil { + clauses = append(clauses, fmt.Sprintf("volumes.organization_id = $%d", len(args)+1)) + args = append(args, *filter.OrganizationID) + } if len(filter.RunnerIDs) > 0 { clauses = append(clauses, fmt.Sprintf("volumes.runner_id = ANY($%d)", len(args)+1)) diff --git a/internal/server/volumes_test.go b/internal/server/volumes_test.go index 5d99423..a563215 100644 --- a/internal/server/volumes_test.go +++ b/internal/server/volumes_test.go @@ -119,6 +119,75 @@ func TestListVolumesFiltersOrganization(t *testing.T) { } } +func TestListVolumesInternalNoIdentity(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + volumeID := uuid.New() + volumeResourceID := uuid.New() + threadID := uuid.New() + runnerID := uuid.New() + agentID := uuid.New() + organizationID := uuid.New() + now := time.Now().UTC() + + rows := pgxmock.NewRows(volumeRowColumns). + AddRow(volumeID, nil, volumeResourceID, threadID, runnerID, agentID, organizationID, "10", volumeStatusActive, nil, nil, now, now) + volumeIDRows := pgxmock.NewRows([]string{"volume_id"}).AddRow(volumeResourceID) + volumeIDQuery := "SELECT DISTINCT volume_id FROM volumes" + mockPool.ExpectQuery(regexp.QuoteMeta(volumeIDQuery)). + WillReturnRows(volumeIDRows) + + volumeName := "volume-name" + limit := normalizePageSize(0) + sortExpr := "CASE volumes.volume_id WHEN $1 THEN $2 END" + query := fmt.Sprintf("SELECT %s FROM volumes ORDER BY %s ASC, volumes.id ASC LIMIT $3", volumeColumns, sortExpr) + mockPool.ExpectQuery(regexp.QuoteMeta(query)). + WithArgs(volumeResourceID, strings.ToLower(volumeName), int(limit)+1). + WillReturnRows(rows) + + agentsClient := fakeAgentsClient{ + getVolume: func(ctx context.Context, req *agentsv1.GetVolumeRequest) (*agentsv1.GetVolumeResponse, error) { + return &agentsv1.GetVolumeResponse{Volume: &agentsv1.Volume{Description: volumeName}}, nil + }, + listVolumeAttachments: func(ctx context.Context, req *agentsv1.ListVolumeAttachmentsRequest) (*agentsv1.ListVolumeAttachmentsResponse, error) { + return &agentsv1.ListVolumeAttachmentsResponse{}, nil + }, + } + + checkCalls := 0 + authorizationClient := fakeAuthorizationClient{ + check: func(ctx context.Context, req *authorizationv1.CheckRequest) (*authorizationv1.CheckResponse, error) { + checkCalls++ + return &authorizationv1.CheckResponse{Allowed: true}, nil + }, + } + + srv := New(Options{Pool: mockPool, AuthorizationClient: authorizationClient, AgentsClient: agentsClient}) + resp, err := srv.ListVolumes(context.Background(), &runnersv1.ListVolumesRequest{}) + if err != nil { + t.Fatalf("ListVolumes failed: %v", err) + } + if len(resp.GetVolumes()) != 1 { + t.Fatalf("expected 1 volume, got %d", len(resp.GetVolumes())) + } + if resp.GetVolumes()[0].GetOrganizationId() != organizationID.String() { + t.Fatalf("expected organization id %q, got %q", organizationID.String(), resp.GetVolumes()[0].GetOrganizationId()) + } + if resp.GetVolumes()[0].GetVolumeName() != volumeName { + t.Fatalf("expected volume name %q, got %q", volumeName, resp.GetVolumes()[0].GetVolumeName()) + } + if checkCalls != 0 { + t.Fatalf("expected no authorization checks, got %d", checkCalls) + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestListVolumesFiltersRunner(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil { @@ -508,6 +577,10 @@ func TestListVolumesInvalidUUID(t *testing.T) { name string req *runnersv1.ListVolumesRequest }{ + { + name: "organization_id_missing", + req: &runnersv1.ListVolumesRequest{}, + }, { name: "organization_id", req: func() *runnersv1.ListVolumesRequest { @@ -617,6 +690,57 @@ func TestListVolumesRequiresViewVolumes(t *testing.T) { } } +func TestListVolumesByThreadInternalNoIdentity(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + volumeID := uuid.New() + volumeResourceID := uuid.New() + threadID := uuid.New() + runnerID := uuid.New() + agentID := uuid.New() + organizationID := uuid.New() + now := time.Now().UTC() + limit := normalizePageSize(0) + + rows := pgxmock.NewRows(volumeRowColumns). + AddRow(volumeID, nil, volumeResourceID, threadID, runnerID, agentID, organizationID, "10", volumeStatusActive, nil, nil, now, now) + + query := fmt.Sprintf("SELECT %s FROM volumes WHERE thread_id = $1 ORDER BY id ASC LIMIT $2", volumeColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(query)). + WithArgs(threadID, int(limit)+1). + WillReturnRows(rows) + + checkCalls := 0 + authorizationClient := fakeAuthorizationClient{ + check: func(ctx context.Context, req *authorizationv1.CheckRequest) (*authorizationv1.CheckResponse, error) { + checkCalls++ + return &authorizationv1.CheckResponse{Allowed: false}, nil + }, + } + + srv := New(Options{Pool: mockPool, AuthorizationClient: authorizationClient}) + resp, err := srv.ListVolumesByThread(context.Background(), &runnersv1.ListVolumesByThreadRequest{ThreadId: threadID.String()}) + if err != nil { + t.Fatalf("ListVolumesByThread failed: %v", err) + } + if len(resp.GetVolumes()) != 1 { + t.Fatalf("expected 1 volume, got %d", len(resp.GetVolumes())) + } + if resp.GetVolumes()[0].GetOrganizationId() != organizationID.String() { + t.Fatalf("expected organization id %q, got %q", organizationID.String(), resp.GetVolumes()[0].GetOrganizationId()) + } + if checkCalls != 0 { + t.Fatalf("expected no authorization checks, got %d", checkCalls) + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestGetVolumeRequiresViewVolumes(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil { diff --git a/internal/server/workloads.go b/internal/server/workloads.go index a12b3b3..8e109d5 100644 --- a/internal/server/workloads.go +++ b/internal/server/workloads.go @@ -91,7 +91,7 @@ type workloadUpdateInput struct { } type workloadListFilter struct { - OrganizationID uuid.UUID + OrganizationID *uuid.UUID AgentIDs []uuid.UUID RunnerIDs []uuid.UUID Statuses []string @@ -429,7 +429,7 @@ func (s *Server) GetWorkload(ctx context.Context, req *runnersv1.GetWorkloadRequ } func (s *Server) ListWorkloadsByThread(ctx context.Context, req *runnersv1.ListWorkloadsByThreadRequest) (*runnersv1.ListWorkloadsByThreadResponse, error) { - callerID, err := identityFromMetadata(ctx) + callerID, err := identityFromMetadataOptional(ctx) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "unauthenticated: %v", err) } @@ -457,18 +457,21 @@ func (s *Server) ListWorkloadsByThread(ctx context.Context, req *runnersv1.ListW } return nil, status.Errorf(codes.Internal, "list workloads: %v", err) } - memberCache := map[uuid.UUID]bool{} - filtered := make([]workloadRecord, 0, len(workloads)) - for _, workload := range workloads { - allowed, err := s.memberAllowed(ctx, callerID, workload.OrganizationID, memberCache) - if err != nil { - return nil, err - } - if allowed { - filtered = append(filtered, workload) + if callerID != nil { + memberCache := map[uuid.UUID]bool{} + filtered := make([]workloadRecord, 0, len(workloads)) + for _, workload := range workloads { + allowed, err := s.memberAllowed(ctx, *callerID, workload.OrganizationID, memberCache) + if err != nil { + return nil, err + } + if allowed { + filtered = append(filtered, workload) + } } + workloads = filtered } - protoWorkloads, err := toProtoWorkloadList(filtered) + protoWorkloads, err := toProtoWorkloadList(workloads) if err != nil { return nil, status.Errorf(codes.Internal, "convert workloads: %v", err) } @@ -476,20 +479,30 @@ func (s *Server) ListWorkloadsByThread(ctx context.Context, req *runnersv1.ListW } func (s *Server) ListWorkloads(ctx context.Context, req *runnersv1.ListWorkloadsRequest) (*runnersv1.ListWorkloadsResponse, error) { - callerID, err := identityFromMetadata(ctx) + callerID, err := identityFromMetadataOptional(ctx) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "unauthenticated: %v", err) } orgValue := strings.TrimSpace(req.GetOrganizationId()) - if orgValue == "" { - return nil, status.Error(codes.InvalidArgument, "organization_id: value is empty") - } - organizationID, err := parseUUID(orgValue) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "organization_id: %v", err) - } - if err := s.requireRelation(ctx, callerID, organizationViewWorkloads, organizationObject(organizationID)); err != nil { - return nil, err + var organizationID *uuid.UUID + if callerID != nil { + if orgValue == "" { + return nil, status.Error(codes.InvalidArgument, "organization_id: value is empty") + } + parsed, err := parseUUID(orgValue) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "organization_id: %v", err) + } + organizationID = &parsed + if err := s.requireRelation(ctx, *callerID, organizationViewWorkloads, organizationObject(*organizationID)); err != nil { + return nil, err + } + } else if orgValue != "" { + parsed, err := parseUUID(orgValue) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "organization_id: %v", err) + } + organizationID = &parsed } filter := workloadListFilter{OrganizationID: organizationID} @@ -969,8 +982,12 @@ func buildWorkloadAgentSortExpr(agentNames map[uuid.UUID]string, startIndex int) func (s *Server) listWorkloads(ctx context.Context, filter workloadListFilter, sort workloadListSort, pageSize int32, pageToken string) ([]workloadRecord, string, error) { limit := normalizePageSize(pageSize) - clauses := []string{fmt.Sprintf("workloads.organization_id = $%d", 1)} - args := []any{filter.OrganizationID} + clauses := []string{} + args := []any{} + if filter.OrganizationID != nil { + clauses = append(clauses, fmt.Sprintf("workloads.organization_id = $%d", len(args)+1)) + args = append(args, *filter.OrganizationID) + } if len(filter.AgentIDs) > 0 { clauses = append(clauses, fmt.Sprintf("workloads.agent_id = ANY($%d)", len(args)+1)) diff --git a/internal/server/workloads_test.go b/internal/server/workloads_test.go index 496dc15..74fe875 100644 --- a/internal/server/workloads_test.go +++ b/internal/server/workloads_test.go @@ -137,6 +137,74 @@ func TestListWorkloadsFiltersOrganization(t *testing.T) { } } +func TestListWorkloadsInternalNoIdentity(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + workloadID := uuid.New() + runnerID := uuid.New() + threadID := uuid.New() + agentID := uuid.New() + organizationID := uuid.New() + now := time.Now().UTC() + containersJSON := []byte("[]") + + rows := pgxmock.NewRows(workloadRowColumns). + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + + limit := normalizePageSize(0) + query := fmt.Sprintf("SELECT %s FROM workloads ORDER BY workloads.created_at DESC, workloads.id ASC LIMIT $1", workloadColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(query)). + WithArgs(int(limit) + 1). + WillReturnRows(rows) + + runnerName := "runner-name" + runnerRows := pgxmock.NewRows([]string{"id", "name"}).AddRow(runnerID, runnerName) + mockPool.ExpectQuery(regexp.QuoteMeta("SELECT id, name FROM runners WHERE id = ANY($1)")). + WithArgs(pgtype.FlatArray[uuid.UUID]([]uuid.UUID{runnerID})). + WillReturnRows(runnerRows) + + agentName := "agent-name" + agentsClient := fakeAgentsClient{getAgent: func(ctx context.Context, req *agentsv1.GetAgentRequest) (*agentsv1.GetAgentResponse, error) { + return &agentsv1.GetAgentResponse{Agent: &agentsv1.Agent{Name: agentName}}, nil + }} + + checkCalls := 0 + authorizationClient := fakeAuthorizationClient{ + check: func(ctx context.Context, req *authorizationv1.CheckRequest) (*authorizationv1.CheckResponse, error) { + checkCalls++ + return &authorizationv1.CheckResponse{Allowed: true}, nil + }, + } + + srv := New(Options{Pool: mockPool, AuthorizationClient: authorizationClient, AgentsClient: agentsClient}) + resp, err := srv.ListWorkloads(context.Background(), &runnersv1.ListWorkloadsRequest{}) + if err != nil { + t.Fatalf("ListWorkloads failed: %v", err) + } + if len(resp.GetWorkloads()) != 1 { + t.Fatalf("expected 1 workload, got %d", len(resp.GetWorkloads())) + } + if resp.GetWorkloads()[0].GetOrganizationId() != organizationID.String() { + t.Fatalf("expected organization id %q, got %q", organizationID.String(), resp.GetWorkloads()[0].GetOrganizationId()) + } + if resp.GetWorkloads()[0].GetAgentName() != agentName { + t.Fatalf("expected agent name %q, got %q", agentName, resp.GetWorkloads()[0].GetAgentName()) + } + if resp.GetWorkloads()[0].GetRunnerName() != runnerName { + t.Fatalf("expected runner name %q, got %q", runnerName, resp.GetWorkloads()[0].GetRunnerName()) + } + if checkCalls != 0 { + t.Fatalf("expected no authorization checks, got %d", checkCalls) + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestListWorkloadsFiltersRunner(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil { @@ -310,7 +378,7 @@ func TestListWorkloadsCursorPagination(t *testing.T) { WillReturnRows(rows) srv := New(Options{Pool: mockPool}) - filter := workloadListFilter{OrganizationID: organizationID} + filter := workloadListFilter{OrganizationID: &organizationID} sort, err := parseWorkloadSort(nil) if err != nil { t.Fatalf("parse sort: %v", err) @@ -374,7 +442,7 @@ func TestListWorkloadsSortByAgentQuery(t *testing.T) { } srv := New(Options{Pool: mockPool, AgentsClient: agentsClient}) - filter := workloadListFilter{OrganizationID: organizationID} + filter := workloadListFilter{OrganizationID: &organizationID} sort := workloadListSort{Field: workloadSortAgent, Direction: sortAsc} workloads, nextToken, err := srv.listWorkloads(context.Background(), filter, sort, pageSize, pageToken) if err != nil { @@ -424,7 +492,7 @@ func TestListWorkloadsSortByRunnerQuery(t *testing.T) { WillReturnRows(rows) srv := New(Options{Pool: mockPool}) - filter := workloadListFilter{OrganizationID: organizationID} + filter := workloadListFilter{OrganizationID: &organizationID} sort := workloadListSort{Field: workloadSortRunner, Direction: sortDesc} workloads, nextToken, err := srv.listWorkloads(context.Background(), filter, sort, pageSize, pageToken) if err != nil { @@ -456,6 +524,10 @@ func TestListWorkloadsInvalidUUID(t *testing.T) { name string req *runnersv1.ListWorkloadsRequest }{ + { + name: "organization_id_missing", + req: &runnersv1.ListWorkloadsRequest{}, + }, { name: "organization_id", req: func() *runnersv1.ListWorkloadsRequest { @@ -589,6 +661,57 @@ func TestListWorkloadsByThreadFilters(t *testing.T) { } } +func TestListWorkloadsByThreadInternalNoIdentity(t *testing.T) { + mockPool, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("failed to create mock pool: %v", err) + } + + workloadID := uuid.New() + runnerID := uuid.New() + threadID := uuid.New() + agentID := uuid.New() + organizationID := uuid.New() + now := time.Now().UTC() + containersJSON := []byte("[]") + limit := normalizePageSize(0) + + rows := pgxmock.NewRows(workloadRowColumns). + AddRow(workloadID, runnerID, threadID, agentID, organizationID, workloadStatusRunning, nil, nil, containersJSON, "ziti-id", int32(0), int64(0), nil, now, nil, nil, now, now) + + query := fmt.Sprintf("SELECT %s FROM workloads WHERE thread_id = $1 ORDER BY created_at DESC, id DESC LIMIT $2", workloadColumns) + mockPool.ExpectQuery(regexp.QuoteMeta(query)). + WithArgs(threadID, int(limit)+1). + WillReturnRows(rows) + + checkCalls := 0 + authorizationClient := fakeAuthorizationClient{ + check: func(ctx context.Context, req *authorizationv1.CheckRequest) (*authorizationv1.CheckResponse, error) { + checkCalls++ + return &authorizationv1.CheckResponse{Allowed: false}, nil + }, + } + + srv := New(Options{Pool: mockPool, AuthorizationClient: authorizationClient}) + resp, err := srv.ListWorkloadsByThread(context.Background(), &runnersv1.ListWorkloadsByThreadRequest{ThreadId: threadID.String()}) + if err != nil { + t.Fatalf("ListWorkloadsByThread failed: %v", err) + } + if len(resp.GetWorkloads()) != 1 { + t.Fatalf("expected 1 workload, got %d", len(resp.GetWorkloads())) + } + if resp.GetWorkloads()[0].GetOrganizationId() != organizationID.String() { + t.Fatalf("expected organization id %q, got %q", organizationID.String(), resp.GetWorkloads()[0].GetOrganizationId()) + } + if checkCalls != 0 { + t.Fatalf("expected no authorization checks, got %d", checkCalls) + } + + if err := mockPool.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet expectations: %v", err) + } +} + func TestListWorkloadsByThreadPagination(t *testing.T) { mockPool, err := pgxmock.NewPool() if err != nil {