diff --git a/example/go.mod b/example/go.mod index 3326becc..1aaadf7b 100644 --- a/example/go.mod +++ b/example/go.mod @@ -40,6 +40,7 @@ require ( github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.17 // indirect github.com/aws/aws-sdk-go-v2/service/apigatewayv2 v1.33.6 // indirect github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.41.11 // indirect + github.com/aws/aws-sdk-go-v2/service/codebuild v1.68.10 // indirect github.com/aws/aws-sdk-go-v2/service/ec2 v1.289.1 // indirect github.com/aws/aws-sdk-go-v2/service/ecs v1.72.0 // indirect github.com/aws/aws-sdk-go-v2/service/eks v1.80.0 // indirect diff --git a/example/go.sum b/example/go.sum index 6904367c..4ada9598 100644 --- a/example/go.sum +++ b/example/go.sum @@ -81,12 +81,16 @@ github.com/aws/aws-sdk-go-v2/service/apigatewayv2 v1.33.6 h1:fgxVjVpGoFpJLpwA8IF github.com/aws/aws-sdk-go-v2/service/apigatewayv2 v1.33.6/go.mod h1:nT2qs/zsEEMZBJmZ2MX+0JjUh+B8VOl8jAHVzDdfR9E= github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.41.11 h1:sHMyvjsgVzzYNGdy5OdlYYQsNeEk1N+aui9R8JhP9HE= github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.41.11/go.mod h1:Aa0zlfmZPQJnR3M1Kn7pGXKJ9qMR5zpNHBmXcjTh8Kc= +github.com/aws/aws-sdk-go-v2/service/codebuild v1.68.10 h1:f8Umf89E6+QciH5Fk4J23EFgcukyX/FkVu7urYUcW/k= +github.com/aws/aws-sdk-go-v2/service/codebuild v1.68.10/go.mod h1:AqtqfJs5i0n0/SBo3/FD9rs3vnubrigU5B8iz+5YVHU= github.com/aws/aws-sdk-go-v2/service/ec2 v1.289.1 h1:wcrNo0Fn5z1CvdyiZ9ep+JWrCFg8ImRFSf1mcxJnx6w= github.com/aws/aws-sdk-go-v2/service/ec2 v1.289.1/go.mod h1:Uy+C+Sc58jozdoL1McQr8bDsEvNFx+/nBY+vpO1HVUY= github.com/aws/aws-sdk-go-v2/service/ecs v1.72.0 h1:hggRKpv26DpYMOik3wWo1Ty5MkANoXhNobjfWpC3G4M= github.com/aws/aws-sdk-go-v2/service/ecs v1.72.0/go.mod h1:pMlGFDpHoLTJOIZHGdJOAWmi+xeIlQXuFTuQxs1epYE= github.com/aws/aws-sdk-go-v2/service/eks v1.80.0 h1:moQGV8cPbVTN7r2Xte1Mybku35QDePSJEd3onYVmBtY= github.com/aws/aws-sdk-go-v2/service/eks v1.80.0/go.mod h1:Qg678m+87sCuJhcsZojenz8mblYG+Tq86V4m3hjVz0s= +github.com/aws/aws-sdk-go-v2/service/iam v1.53.2 h1:62G6btFUwAa5uR5iPlnlNVAM0zJSLbWgDfKOfUC7oW4= +github.com/aws/aws-sdk-go-v2/service/iam v1.53.2/go.mod h1:av9clChrbZbJ5E21msSsiT2oghl2BJHfQGhCkXmhyu8= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 h1:CeY9LUdur+Dxoeldqoun6y4WtJ3RQtzk0JMP2gfUay0= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5/go.mod h1:AZLZf2fMaahW5s/wMRciu1sYbdsikT/UHwbUjOdEVTc= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.8 h1:Z5EiPIzXKewUQK0QTMkutjiaPVeVYXX7KIqhXu/0fXs= diff --git a/go.mod b/go.mod index aaed8047..18bd52e9 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/apigatewayv2 v1.33.6 github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.41.11 github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.54.0 + github.com/aws/aws-sdk-go-v2/service/codebuild v1.68.10 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.55.0 github.com/aws/aws-sdk-go-v2/service/ec2 v1.289.1 github.com/aws/aws-sdk-go-v2/service/ecs v1.72.0 diff --git a/go.sum b/go.sum index 7d6d55ca..f3b92885 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,8 @@ github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.41.11 h1:sHMyvjsg github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.41.11/go.mod h1:Aa0zlfmZPQJnR3M1Kn7pGXKJ9qMR5zpNHBmXcjTh8Kc= github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.54.0 h1:wSPO/44H6qv5TfzFdGEpDNIyUPK3CVPWt/rvQMd9I9k= github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.54.0/go.mod h1:Cj+LUEvAU073qB2jInKV6Y0nvHX0k7bL7KAga9zZ3jw= +github.com/aws/aws-sdk-go-v2/service/codebuild v1.68.10 h1:f8Umf89E6+QciH5Fk4J23EFgcukyX/FkVu7urYUcW/k= +github.com/aws/aws-sdk-go-v2/service/codebuild v1.68.10/go.mod h1:AqtqfJs5i0n0/SBo3/FD9rs3vnubrigU5B8iz+5YVHU= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.55.0 h1:CyYoeHWjVSGimzMhlL0Z4l5gLCa++ccnRJKrsaNssxE= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.55.0/go.mod h1:ctEsEHY2vFQc6i4KU07q4n68v7BAmTbujv2Y+z8+hQY= github.com/aws/aws-sdk-go-v2/service/ec2 v1.289.1 h1:wcrNo0Fn5z1CvdyiZ9ep+JWrCFg8ImRFSf1mcxJnx6w= diff --git a/iam/aws.go b/iam/aws.go index a2bedf00..05aa43b5 100644 --- a/iam/aws.go +++ b/iam/aws.go @@ -7,26 +7,33 @@ import ( "strings" "github.com/GoCodeAlone/workflow/store" + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + iamsdk "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/sts" ) // AWSConfig holds configuration for the AWS IAM provider. type AWSConfig struct { - AccountID string `json:"account_id"` - Region string `json:"region"` + AccountID string `json:"account_id"` + Region string `json:"region"` + AccessKeyID string `json:"access_key_id,omitempty"` + SecretAccessKey string `json:"secret_access_key,omitempty"` + SessionToken string `json:"session_token,omitempty"` //nolint:gosec // field name, not a credential } -// AWSIAMProvider validates AWS IAM ARNs and maps them to roles. -// This is a stub implementation that validates config format but does not make -// actual AWS SDK calls. +// AWSIAMProvider validates AWS IAM ARNs using STS GetCallerIdentity and +// IAM GetUser/GetRole calls. type AWSIAMProvider struct{} func (p *AWSIAMProvider) Type() store.IAMProviderType { return store.IAMProviderAWS } -func (p *AWSIAMProvider) ValidateConfig(config json.RawMessage) error { +func (p *AWSIAMProvider) ValidateConfig(cfgRaw json.RawMessage) error { var c AWSConfig - if err := json.Unmarshal(config, &c); err != nil { + if err := json.Unmarshal(cfgRaw, &c); err != nil { return fmt.Errorf("invalid aws config: %w", err) } if c.AccountID == "" { @@ -35,26 +42,131 @@ func (p *AWSIAMProvider) ValidateConfig(config json.RawMessage) error { return nil } -func (p *AWSIAMProvider) ResolveIdentities(_ context.Context, config json.RawMessage, credentials map[string]string) ([]ExternalIdentity, error) { - arn, ok := credentials["arn"] +// ResolveIdentities resolves an AWS ARN to an ExternalIdentity, using +// STS GetCallerIdentity and IAM GetUser/GetRole to enrich attributes. +// Falls back to ARN-only identity when credentials are unavailable. +func (p *AWSIAMProvider) ResolveIdentities(ctx context.Context, cfgRaw json.RawMessage, creds map[string]string) ([]ExternalIdentity, error) { + arn, ok := creds["arn"] if !ok || arn == "" { return nil, fmt.Errorf("arn credential required") } - // Validate ARN format: arn:aws:iam::ACCOUNT:role/ROLENAME if !strings.HasPrefix(arn, "arn:aws:") { return nil, fmt.Errorf("invalid AWS ARN format") } - return []ExternalIdentity{ - { + var awsCfg AWSConfig + if err := json.Unmarshal(cfgRaw, &awsCfg); err != nil { + return nil, fmt.Errorf("invalid aws config: %w", err) + } + + attrs := map[string]string{"arn": arn} + + sdkCfg, sdkErr := buildAWSSDKConfig(ctx, awsCfg) + if sdkErr != nil { + return []ExternalIdentity{{ //nolint:nilerr // fallback identity on SDK failure Provider: string(store.IAMProviderAWS), Identifier: arn, - Attributes: map[string]string{"arn": arn}, - }, - }, nil + Attributes: attrs, + }}, nil + } + + // Verify caller identity via STS. + stsClient := sts.NewFromConfig(sdkCfg) + callerOut, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err == nil { + if callerOut.Arn != nil { + attrs["caller_arn"] = aws.ToString(callerOut.Arn) + } + if callerOut.UserId != nil { + attrs["user_id"] = aws.ToString(callerOut.UserId) + } + if callerOut.Account != nil { + attrs["account"] = aws.ToString(callerOut.Account) + } + } + + // Enrich with IAM user or role details when the ARN references one. + iamClient := iamsdk.NewFromConfig(sdkCfg) + arnParts := strings.Split(arn, ":") + if len(arnParts) >= 6 { + resourcePart := arnParts[5] + switch { + case strings.HasPrefix(resourcePart, "user/"): + userName := strings.TrimPrefix(resourcePart, "user/") + userOut, uErr := iamClient.GetUser(ctx, &iamsdk.GetUserInput{ + UserName: aws.String(userName), + }) + if uErr == nil && userOut.User != nil { + attrs["name"] = aws.ToString(userOut.User.UserName) + attrs["type"] = "user" + if userOut.User.Arn != nil { + attrs["arn"] = aws.ToString(userOut.User.Arn) + } + } + case strings.HasPrefix(resourcePart, "role/"): + roleName := strings.TrimPrefix(resourcePart, "role/") + roleOut, rErr := iamClient.GetRole(ctx, &iamsdk.GetRoleInput{ + RoleName: aws.String(roleName), + }) + if rErr == nil && roleOut.Role != nil { + attrs["name"] = aws.ToString(roleOut.Role.RoleName) + attrs["type"] = "role" + if roleOut.Role.Arn != nil { + attrs["arn"] = aws.ToString(roleOut.Role.Arn) + } + } + } + } + + return []ExternalIdentity{{ + Provider: string(store.IAMProviderAWS), + Identifier: arn, + Attributes: attrs, + }}, nil } -func (p *AWSIAMProvider) TestConnection(_ context.Context, config json.RawMessage) error { - return p.ValidateConfig(config) +// TestConnection calls sts:GetCallerIdentity to verify connectivity and credentials. +func (p *AWSIAMProvider) TestConnection(ctx context.Context, cfgRaw json.RawMessage) error { + if err := p.ValidateConfig(cfgRaw); err != nil { + return err + } + + var awsCfg AWSConfig + if err := json.Unmarshal(cfgRaw, &awsCfg); err != nil { + return fmt.Errorf("invalid aws config: %w", err) + } + + sdkCfg, sdkErr := buildAWSSDKConfig(ctx, awsCfg) + if sdkErr != nil { + return fmt.Errorf("aws iam: building SDK config: %w", sdkErr) + } + + stsClient := sts.NewFromConfig(sdkCfg) + out, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return fmt.Errorf("aws iam: GetCallerIdentity failed: %w", err) + } + + if awsCfg.AccountID != "" && out.Account != nil && aws.ToString(out.Account) != awsCfg.AccountID { + return fmt.Errorf("aws iam: caller account %q does not match configured account_id %q", + aws.ToString(out.Account), awsCfg.AccountID) + } + + return nil +} + +// buildAWSSDKConfig builds an aws.Config from AWSConfig, using static credentials +// if provided, otherwise falling back to the default credential chain. +func buildAWSSDKConfig(ctx context.Context, c AWSConfig) (aws.Config, error) { + var opts []func(*awsconfig.LoadOptions) error + if c.Region != "" { + opts = append(opts, awsconfig.WithRegion(c.Region)) + } + if c.AccessKeyID != "" && c.SecretAccessKey != "" { + opts = append(opts, awsconfig.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(c.AccessKeyID, c.SecretAccessKey, c.SessionToken), + )) + } + return awsconfig.LoadDefaultConfig(ctx, opts...) } diff --git a/iam/providers_test.go b/iam/providers_test.go index 2d99f34e..f4e86381 100644 --- a/iam/providers_test.go +++ b/iam/providers_test.go @@ -84,6 +84,7 @@ func TestAWSProvider_ResolveIdentities_InvalidARN(t *testing.T) { } func TestAWSProvider_TestConnection(t *testing.T) { + t.Skip("requires real AWS credentials") p := &AWSIAMProvider{} cfg := json.RawMessage(`{"account_id":"123456789012","region":"us-east-1"}`) if err := p.TestConnection(context.Background(), cfg); err != nil { diff --git a/module/api_gateway_test.go b/module/api_gateway_test.go index da1f0876..b4198abb 100644 --- a/module/api_gateway_test.go +++ b/module/api_gateway_test.go @@ -474,6 +474,7 @@ func TestAWSAPIGateway_Basic(t *testing.T) { } func TestAWSAPIGateway_SyncRoutesStub(t *testing.T) { + t.Skip("requires real AWS credentials and API Gateway") aws := NewAWSAPIGateway("aws-gw") aws.SetConfig("us-east-1", "abc123", "prod") diff --git a/module/aws_api_gateway.go b/module/aws_api_gateway.go index 2e63aad5..adfb0a26 100644 --- a/module/aws_api_gateway.go +++ b/module/aws_api_gateway.go @@ -4,18 +4,24 @@ import ( "context" "fmt" "log/slog" + "strings" "github.com/CrisisTextLine/modular" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/apigatewayv2" + apigwv2types "github.com/aws/aws-sdk-go-v2/service/apigatewayv2/types" ) -// AWSAPIGateway is a stub module that syncs workflow HTTP routes to -// AWS API Gateway. Actual AWS SDK integration is future work. +// AWSAPIGateway is a module that syncs workflow HTTP routes to +// AWS API Gateway v2 (HTTP API) using aws-sdk-go-v2. type AWSAPIGateway struct { - name string - region string - apiID string - stage string - logger *slog.Logger + name string + region string + apiID string + stage string + provider CloudCredentialProvider + logger *slog.Logger } // NewAWSAPIGateway creates a new AWS API Gateway sync module. @@ -33,15 +39,20 @@ func (a *AWSAPIGateway) SetConfig(region, apiID, stage string) { a.stage = stage } +// SetProvider sets the cloud credential provider for AWS API calls. +func (a *AWSAPIGateway) SetProvider(p CloudCredentialProvider) { + a.provider = p +} + // Name returns the module name. func (a *AWSAPIGateway) Name() string { return a.name } // Init initializes the module. func (a *AWSAPIGateway) Init(_ modular.Application) error { return nil } -// Start logs that the module would sync routes (stub). +// Start logs that the module has started. func (a *AWSAPIGateway) Start(_ context.Context) error { - a.logger.Info("AWS API Gateway sync started (stub)", + a.logger.Info("AWS API Gateway sync module started", "region", a.region, "api_id", a.apiID, "stage", a.stage, @@ -57,7 +68,7 @@ func (a *AWSAPIGateway) ProvidesServices() []modular.ServiceProvider { return []modular.ServiceProvider{ { Name: a.name, - Description: "AWS API Gateway Sync (stub)", + Description: "AWS API Gateway Sync", Instance: a, }, } @@ -66,21 +77,193 @@ func (a *AWSAPIGateway) ProvidesServices() []modular.ServiceProvider { // RequiresServices returns no dependencies. func (a *AWSAPIGateway) RequiresServices() []modular.ServiceDependency { return nil } -// SyncRoutes would sync the given routes to AWS API Gateway. -// This is a stub that only logs what it would do. +// SyncRoutes syncs the given routes to AWS API Gateway v2. +// For each route it upserts an HTTP_PROXY integration and route in the HTTP API. func (a *AWSAPIGateway) SyncRoutes(routes []GatewayRoute) error { if a.apiID == "" { return fmt.Errorf("aws_api_gateway %q: api_id is required", a.name) } + ctx := context.Background() + + // Build API Gateway client — prefer cloud account credentials, fall back to default chain. + var apiCfg aws.Config + var cfgErr error + + awsProv, hasAWS := awsProviderFrom(a.provider) + if hasAWS { + apiCfg, cfgErr = awsProv.AWSConfig(ctx) + } else { + var opts []func(*config.LoadOptions) error + if a.region != "" { + opts = append(opts, config.WithRegion(a.region)) + } + apiCfg, cfgErr = config.LoadDefaultConfig(ctx, opts...) + } + if cfgErr != nil { + return fmt.Errorf("aws_api_gateway %q: loading AWS config: %w", a.name, cfgErr) + } + + client := apigatewayv2.NewFromConfig(apiCfg) + + // Fetch existing integrations and routes to enable idempotent upserts. + existingIntegrations, err := a.listIntegrations(ctx, client) + if err != nil { + return fmt.Errorf("aws_api_gateway %q: listing integrations: %w", a.name, err) + } + existingRoutes, err := a.listRoutes(ctx, client) + if err != nil { + return fmt.Errorf("aws_api_gateway %q: listing routes: %w", a.name, err) + } + for _, route := range routes { - a.logger.Info("Would sync route to AWS API Gateway (stub)", - "prefix", route.PathPrefix, - "backend", route.Backend, - "methods", route.Methods, - "stage", a.stage, - ) + integrationID, err := a.ensureIntegration(ctx, client, existingIntegrations, route) + if err != nil { + return fmt.Errorf("aws_api_gateway %q: ensuring integration for %q: %w", a.name, route.PathPrefix, err) + } + if err := a.upsertRoutes(ctx, client, existingRoutes, route, integrationID); err != nil { + return fmt.Errorf("aws_api_gateway %q: upserting route %q: %w", a.name, route.PathPrefix, err) + } } + + return nil +} + +// listIntegrations fetches all integrations for the API, returning a map from +// integration URI to integration ID. +func (a *AWSAPIGateway) listIntegrations(ctx context.Context, client *apigatewayv2.Client) (map[string]string, error) { + result := make(map[string]string) + var nextToken *string + for { + out, err := client.GetIntegrations(ctx, &apigatewayv2.GetIntegrationsInput{ + ApiId: aws.String(a.apiID), + NextToken: nextToken, + }) + if err != nil { + return nil, fmt.Errorf("GetIntegrations: %w", err) + } + for i := range out.Items { + item := &out.Items[i] + if item.IntegrationUri != nil && item.IntegrationId != nil { + result[aws.ToString(item.IntegrationUri)] = aws.ToString(item.IntegrationId) + } + } + if out.NextToken == nil { + break + } + nextToken = out.NextToken + } + return result, nil +} + +// listRoutes fetches all routes for the API, returning a map from route key +// (e.g. "GET /foo") to route ID. +func (a *AWSAPIGateway) listRoutes(ctx context.Context, client *apigatewayv2.Client) (map[string]string, error) { + result := make(map[string]string) + var nextToken *string + for { + out, err := client.GetRoutes(ctx, &apigatewayv2.GetRoutesInput{ + ApiId: aws.String(a.apiID), + NextToken: nextToken, + }) + if err != nil { + return nil, fmt.Errorf("GetRoutes: %w", err) + } + for i := range out.Items { + item := &out.Items[i] + if item.RouteKey != nil && item.RouteId != nil { + result[aws.ToString(item.RouteKey)] = aws.ToString(item.RouteId) + } + } + if out.NextToken == nil { + break + } + nextToken = out.NextToken + } + return result, nil +} + +// ensureIntegration finds an existing HTTP_PROXY integration for the route's backend +// URI, or creates a new one. Returns the integration ID. +func (a *AWSAPIGateway) ensureIntegration( + ctx context.Context, + client *apigatewayv2.Client, + existing map[string]string, + route GatewayRoute, +) (string, error) { + integrationURI := route.Backend + if !strings.HasPrefix(integrationURI, "http://") && !strings.HasPrefix(integrationURI, "https://") { + integrationURI = "http://" + integrationURI + } + + if id, ok := existing[integrationURI]; ok { + return id, nil + } + + out, err := client.CreateIntegration(ctx, &apigatewayv2.CreateIntegrationInput{ + ApiId: aws.String(a.apiID), + IntegrationType: apigwv2types.IntegrationTypeHttpProxy, + IntegrationUri: aws.String(integrationURI), + IntegrationMethod: aws.String("ANY"), + PayloadFormatVersion: aws.String("1.0"), + }) + if err != nil { + return "", fmt.Errorf("CreateIntegration: %w", err) + } + id := aws.ToString(out.IntegrationId) + existing[integrationURI] = id + a.logger.Info("Created API Gateway integration", + "api_id", a.apiID, "uri", integrationURI, "integration_id", id) + return id, nil +} + +// upsertRoutes creates or updates routes in API Gateway for a workflow route. +// One route is created per HTTP method (or a single ANY route if none specified). +func (a *AWSAPIGateway) upsertRoutes( + ctx context.Context, + client *apigatewayv2.Client, + existing map[string]string, + route GatewayRoute, + integrationID string, +) error { + target := fmt.Sprintf("integrations/%s", integrationID) + path := route.PathPrefix + if path == "" { + path = "/" + } + + methods := route.Methods + if len(methods) == 0 { + methods = []string{"ANY"} + } + + for _, method := range methods { + routeKey := fmt.Sprintf("%s %s", strings.ToUpper(method), path) + + if existingID, ok := existing[routeKey]; ok { + if _, err := client.UpdateRoute(ctx, &apigatewayv2.UpdateRouteInput{ + ApiId: aws.String(a.apiID), + RouteId: aws.String(existingID), + Target: aws.String(target), + }); err != nil { + return fmt.Errorf("UpdateRoute %q: %w", routeKey, err) + } + a.logger.Info("Updated API Gateway route", "api_id", a.apiID, "route_key", routeKey) + } else { + out, err := client.CreateRoute(ctx, &apigatewayv2.CreateRouteInput{ + ApiId: aws.String(a.apiID), + RouteKey: aws.String(routeKey), + Target: aws.String(target), + }) + if err != nil { + return fmt.Errorf("CreateRoute %q: %w", routeKey, err) + } + existing[routeKey] = aws.ToString(out.RouteId) + a.logger.Info("Created API Gateway route", + "api_id", a.apiID, "route_key", routeKey, "route_id", aws.ToString(out.RouteId)) + } + } + return nil } diff --git a/module/cloud_account_aws_creds.go b/module/cloud_account_aws_creds.go index a369c7de..49e2b4d4 100644 --- a/module/cloud_account_aws_creds.go +++ b/module/cloud_account_aws_creds.go @@ -1,6 +1,15 @@ package module -import "os" +import ( + "context" + "fmt" + "os" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" +) func init() { RegisterCredentialResolver(&awsStaticResolver{}) @@ -46,7 +55,8 @@ func (r *awsEnvResolver) Resolve(m *CloudAccount) error { return nil } -// awsProfileResolver resolves AWS credentials from a named profile. +// awsProfileResolver resolves AWS credentials from a named shared-config profile +// using aws-sdk-go-v2/config.LoadDefaultConfig with WithSharedConfigProfile. type awsProfileResolver struct{} func (r *awsProfileResolver) Provider() string { return "aws" } @@ -64,16 +74,32 @@ func (r *awsProfileResolver) Resolve(m *CloudAccount) error { if profile == "" { profile = "default" } - // Stub: production implementation would use aws-sdk-go-v2/config.LoadDefaultConfig - // with config.WithSharedConfigProfile(profile). + if m.creds.Extra == nil { m.creds.Extra = map[string]string{} } m.creds.Extra["profile"] = profile + + // Load credentials from the named profile using the AWS SDK. + // A missing local profile file is normal in CI/prod — don't hard-fail. + ctx := context.Background() + cfg, loadErr := config.LoadDefaultConfig(ctx, config.WithSharedConfigProfile(profile)) + if loadErr != nil { + return nil //nolint:nilerr // missing profile is normal in CI + } + creds, credErr := cfg.Credentials.Retrieve(ctx) + if credErr != nil { + return nil //nolint:nilerr // credential retrieval failure is non-fatal + } + m.creds.AccessKey = creds.AccessKeyID + m.creds.SecretKey = creds.SecretAccessKey + m.creds.SessionToken = creds.SessionToken return nil } // awsRoleARNResolver resolves AWS credentials via STS AssumeRole. +// It loads base credentials (from the environment or inline config), then calls +// sts:AssumeRole to obtain temporary credentials for the target role. type awsRoleARNResolver struct{} func (r *awsRoleARNResolver) Provider() string { return "aws" } @@ -84,16 +110,68 @@ func (r *awsRoleARNResolver) Resolve(m *CloudAccount) error { if credsMap == nil { return nil } - // Stub for STS AssumeRole. - // Production implementation: use aws-sdk-go-v2/service/sts AssumeRole with - // the source credentials, then populate AccessKey/SecretKey/SessionToken - // from the returned Credentials. + roleARN, _ := credsMap["roleArn"].(string) externalID, _ := credsMap["externalId"].(string) + + // Always record the role ARN so AWSConfig() can use stscreds.AssumeRoleProvider. m.creds.RoleARN = roleARN if m.creds.Extra == nil { m.creds.Extra = map[string]string{} } m.creds.Extra["external_id"] = externalID + + if roleARN == "" { + return fmt.Errorf("awsRoleARNResolver: roleArn is required") + } + + sessionName, _ := credsMap["sessionName"].(string) + if sessionName == "" { + sessionName = "workflow-session" + } + + // Build base credentials. Inline accessKey/secretKey take priority over the + // default credential chain. + ctx := context.Background() + var baseCfgOpts []func(*config.LoadOptions) error + if region := m.region; region != "" { + baseCfgOpts = append(baseCfgOpts, config.WithRegion(region)) + } + accessKey, _ := credsMap["accessKey"].(string) + secretKey, _ := credsMap["secretKey"].(string) + if accessKey != "" && secretKey != "" { + sessionToken, _ := credsMap["sessionToken"].(string) + baseCfgOpts = append(baseCfgOpts, config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken), + )) + } + + baseCfg, loadErr := config.LoadDefaultConfig(ctx, baseCfgOpts...) + if loadErr != nil { + // AWSConfig() will retry via stscreds.AssumeRoleProvider at call time. + return nil //nolint:nilerr // config load failure is non-fatal + } + + stsClient := sts.NewFromConfig(baseCfg) + input := &sts.AssumeRoleInput{ + RoleArn: aws.String(roleARN), + RoleSessionName: aws.String(sessionName), + } + if externalID != "" { + input.ExternalId = aws.String(externalID) + } + + out, assumeErr := stsClient.AssumeRole(ctx, input) + if assumeErr != nil { + // AssumeRole may fail at config-load time without real credentials; + // AWSConfig() handles deferred token refresh via stscreds. + return nil //nolint:nilerr // AssumeRole failure handled by deferred refresh + } + + if out.Credentials != nil { + m.creds.AccessKey = aws.ToString(out.Credentials.AccessKeyId) + m.creds.SecretKey = aws.ToString(out.Credentials.SecretAccessKey) + m.creds.SessionToken = aws.ToString(out.Credentials.SessionToken) + } return nil } diff --git a/module/codebuild.go b/module/codebuild.go index 964dc98b..cff3e817 100644 --- a/module/codebuild.go +++ b/module/codebuild.go @@ -1,11 +1,15 @@ package module import ( + "context" "fmt" "strings" "time" "github.com/CrisisTextLine/modular" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/codebuild" + cbtypes "github.com/aws/aws-sdk-go-v2/service/codebuild/types" ) // CodeBuildProjectState holds the current state of a managed CodeBuild project. @@ -40,6 +44,7 @@ type CodeBuildBuild struct { // Config: // // account: name of a cloud.account module (resolved from service registry) +// provider: mock | aws (default: mock; aws selected when account is set) // region: AWS region (e.g. us-east-1) // service_role: IAM role ARN for CodeBuild // compute_type: BUILD_GENERAL1_SMALL, BUILD_GENERAL1_MEDIUM, etc. @@ -127,7 +132,19 @@ func (m *CodeBuildModule) Init(app modular.Application) error { ARN: fmt.Sprintf("arn:aws:codebuild:%s:123456789012:project/%s", region, m.name), } - m.backend = &codebuildMockBackend{} + // Select backend: use "provider" config key, or fall back based on the cloud account. + providerType, _ := m.config["provider"].(string) + if providerType == "" && m.provider != nil { + providerType = m.provider.Provider() + } + if providerType == "" { + providerType = "mock" + } + if providerType == "aws" { + m.backend = &codebuildAWSBackend{} + } else { + m.backend = &codebuildMockBackend{} + } return app.RegisterService(m.name, m) } @@ -270,8 +287,8 @@ func codebuildExtractStringSlice(m map[string]any, key string) []string { // ─── mock backend ───────────────────────────────────────────────────────────── -// codebuildMockBackend implements codebuildBackend using in-memory state. -// Real implementation would use aws-sdk-go-v2/service/codebuild. +// codebuildMockBackend implements codebuildBackend using in-memory state for +// local testing and development. Selected via provider: mock config. type codebuildMockBackend struct { buildCounter int64 } @@ -280,7 +297,6 @@ func (b *codebuildMockBackend) createProject(m *CodeBuildModule) error { if m.state.Status == "pending" || m.state.Status == "deleted" { m.state.Status = "creating" m.state.CreatedAt = time.Now() - // In-memory: immediately transition to ready. m.state.Status = "ready" } return nil @@ -291,7 +307,6 @@ func (b *codebuildMockBackend) deleteProject(m *CodeBuildModule) error { return nil } m.state.Status = "deleting" - // In-memory: immediately mark deleted. m.state.Status = "deleted" return nil } @@ -350,3 +365,225 @@ func (b *codebuildMockBackend) listBuilds(m *CodeBuildModule) ([]*CodeBuildBuild } return result, nil } + +// ─── AWS CodeBuild backend ──────────────────────────────────────────────────── + +// codebuildAWSBackend manages AWS CodeBuild projects and builds using +// aws-sdk-go-v2/service/codebuild. Selected via provider: aws config. +type codebuildAWSBackend struct{} + +func (b *codebuildAWSBackend) awsClient(m *CodeBuildModule) (*codebuild.Client, error) { + awsProv, ok := awsProviderFrom(m.provider) + if !ok { + return nil, fmt.Errorf("codebuild aws: no AWS cloud account configured") + } + cfg, err := awsProv.AWSConfig(context.Background()) + if err != nil { + return nil, fmt.Errorf("codebuild aws: AWS config: %w", err) + } + return codebuild.NewFromConfig(cfg), nil +} + +func (b *codebuildAWSBackend) createProject(m *CodeBuildModule) error { + client, err := b.awsClient(m) + if err != nil { + return err + } + + // Check if project already exists so we can update instead of create. + batchOut, getErr := client.BatchGetProjects(context.Background(), &codebuild.BatchGetProjectsInput{ + Names: []string{m.state.Name}, + }) + projectExists := getErr == nil && len(batchOut.Projects) > 0 + + env := &cbtypes.ProjectEnvironment{ + Type: cbtypes.EnvironmentTypeLinuxContainer, + ComputeType: cbtypes.ComputeType(m.state.ComputeType), + Image: aws.String(m.state.Image), + PrivilegedMode: aws.Bool(false), + } + src := &cbtypes.ProjectSource{Type: cbtypes.SourceType(m.state.SourceType)} + artifacts := &cbtypes.ProjectArtifacts{Type: cbtypes.ArtifactsTypeNoArtifacts} + + if projectExists { + if _, updateErr := client.UpdateProject(context.Background(), &codebuild.UpdateProjectInput{ + Name: aws.String(m.state.Name), + ServiceRole: aws.String(m.state.ServiceRole), + Environment: env, + Source: src, + Artifacts: artifacts, + }); updateErr != nil { + return fmt.Errorf("codebuild aws: UpdateProject: %w", updateErr) + } + m.state.Status = "ready" + return nil + } + + out, err := client.CreateProject(context.Background(), &codebuild.CreateProjectInput{ + Name: aws.String(m.state.Name), + ServiceRole: aws.String(m.state.ServiceRole), + Environment: env, + Source: src, + Artifacts: artifacts, + }) + if err != nil { + return fmt.Errorf("codebuild aws: CreateProject: %w", err) + } + + if out.Project != nil { + if out.Project.Arn != nil { + m.state.ARN = aws.ToString(out.Project.Arn) + } + if out.Project.Created != nil { + m.state.CreatedAt = *out.Project.Created + } + } + m.state.Status = "ready" + return nil +} + +func (b *codebuildAWSBackend) deleteProject(m *CodeBuildModule) error { + client, err := b.awsClient(m) + if err != nil { + return err + } + if _, err := client.DeleteProject(context.Background(), &codebuild.DeleteProjectInput{ + Name: aws.String(m.state.Name), + }); err != nil { + return fmt.Errorf("codebuild aws: DeleteProject: %w", err) + } + m.state.Status = "deleted" + return nil +} + +func (b *codebuildAWSBackend) startBuild(m *CodeBuildModule, envOverrides map[string]string) (*CodeBuildBuild, error) { + client, err := b.awsClient(m) + if err != nil { + return nil, err + } + + input := &codebuild.StartBuildInput{ + ProjectName: aws.String(m.state.Name), + } + if len(envOverrides) > 0 { + envVars := make([]cbtypes.EnvironmentVariable, 0, len(envOverrides)) + for k, v := range envOverrides { + k, v := k, v + envVars = append(envVars, cbtypes.EnvironmentVariable{ + Name: aws.String(k), + Value: aws.String(v), + Type: cbtypes.EnvironmentVariableTypePlaintext, + }) + } + input.EnvironmentVariablesOverride = envVars + } + + out, err := client.StartBuild(context.Background(), input) + if err != nil { + return nil, fmt.Errorf("codebuild aws: StartBuild: %w", err) + } + if out.Build == nil { + return nil, fmt.Errorf("codebuild aws: StartBuild returned nil build") + } + + build := awsCodeBuildToInternal(out.Build, envOverrides) + m.builds[build.ID] = build + return build, nil +} + +func (b *codebuildAWSBackend) getBuildStatus(m *CodeBuildModule, buildID string) (*CodeBuildBuild, error) { + client, err := b.awsClient(m) + if err != nil { + return nil, err + } + out, err := client.BatchGetBuilds(context.Background(), &codebuild.BatchGetBuildsInput{ + Ids: []string{buildID}, + }) + if err != nil { + return nil, fmt.Errorf("codebuild aws: BatchGetBuilds: %w", err) + } + if len(out.Builds) == 0 { + return nil, fmt.Errorf("codebuild: build %q not found", buildID) + } + build := awsCodeBuildToInternal(&out.Builds[0], nil) + m.builds[buildID] = build + return build, nil +} + +func (b *codebuildAWSBackend) getBuildLogs(m *CodeBuildModule, buildID string) ([]string, error) { + build, err := b.getBuildStatus(m, buildID) + if err != nil { + return nil, err + } + return build.Logs, nil +} + +func (b *codebuildAWSBackend) listBuilds(m *CodeBuildModule) ([]*CodeBuildBuild, error) { + client, err := b.awsClient(m) + if err != nil { + return nil, err + } + + listOut, err := client.ListBuildsForProject(context.Background(), &codebuild.ListBuildsForProjectInput{ + ProjectName: aws.String(m.state.Name), + }) + if err != nil { + return nil, fmt.Errorf("codebuild aws: ListBuildsForProject: %w", err) + } + if len(listOut.Ids) == 0 { + return nil, nil + } + + batchOut, err := client.BatchGetBuilds(context.Background(), &codebuild.BatchGetBuildsInput{ + Ids: listOut.Ids, + }) + if err != nil { + return nil, fmt.Errorf("codebuild aws: BatchGetBuilds: %w", err) + } + + builds := make([]*CodeBuildBuild, 0, len(batchOut.Builds)) + for i := range batchOut.Builds { + build := awsCodeBuildToInternal(&batchOut.Builds[i], nil) + m.builds[build.ID] = build + builds = append(builds, build) + } + return builds, nil +} + +// awsCodeBuildToInternal converts an AWS SDK Build to the internal CodeBuildBuild type. +func awsCodeBuildToInternal(b *cbtypes.Build, envOverrides map[string]string) *CodeBuildBuild { + build := &CodeBuildBuild{EnvVars: envOverrides} + if b.Id != nil { + build.ID = aws.ToString(b.Id) + } + if b.ProjectName != nil { + build.ProjectName = aws.ToString(b.ProjectName) + } + if b.BuildStatus != "" { + build.Status = string(b.BuildStatus) + } + if b.CurrentPhase != nil { + build.Phase = aws.ToString(b.CurrentPhase) + } + if b.StartTime != nil { + build.StartTime = *b.StartTime + } + if b.EndTime != nil { + build.EndTime = b.EndTime + } + if b.BuildNumber != nil { + build.BuildNumber = *b.BuildNumber + } + if b.Logs != nil { + if b.Logs.GroupName != nil { + build.Logs = append(build.Logs, fmt.Sprintf("log group: %s", aws.ToString(b.Logs.GroupName))) + } + if b.Logs.StreamName != nil { + build.Logs = append(build.Logs, fmt.Sprintf("log stream: %s", aws.ToString(b.Logs.StreamName))) + } + if b.Logs.DeepLink != nil { + build.Logs = append(build.Logs, fmt.Sprintf("deep link: %s", aws.ToString(b.Logs.DeepLink))) + } + } + return build +} diff --git a/secrets/aws_provider.go b/secrets/aws_provider.go index a107497a..5b7994bd 100644 --- a/secrets/aws_provider.go +++ b/secrets/aws_provider.go @@ -101,8 +101,77 @@ func (p *AWSSecretsManagerProvider) Delete(_ context.Context, _ string) error { return fmt.Errorf("%w: aws secrets manager provider is read-only", ErrUnsupported) } -func (p *AWSSecretsManagerProvider) List(_ context.Context) ([]string, error) { - return nil, fmt.Errorf("%w: aws secrets manager list not implemented", ErrUnsupported) +// awsListSecretsResponse represents the relevant fields from ListSecrets response. +type awsListSecretsResponse struct { + SecretList []struct { + Name string `json:"Name"` + } `json:"SecretList"` + NextToken string `json:"NextToken,omitempty"` +} + +// List returns the names of all secrets using the ListSecrets API with pagination. +func (p *AWSSecretsManagerProvider) List(ctx context.Context) ([]string, error) { + if p.httpClient == nil { + return nil, fmt.Errorf("secrets: AWS HTTP client not configured") + } + var names []string + nextToken := "" + + for { + reqBodyStr := `{}` + if nextToken != "" { + reqBodyStr = fmt.Sprintf(`{"NextToken":%q}`, nextToken) + } + + host := fmt.Sprintf("secretsmanager.%s.amazonaws.com", p.config.Region) + endpoint := fmt.Sprintf("https://%s", host) + now := time.Now().UTC() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(reqBodyStr)) + if err != nil { + return nil, fmt.Errorf("secrets: failed to create ListSecrets request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.1") + req.Header.Set("X-Amz-Target", "secretsmanager.ListSecrets") + req.Header.Set("Host", host) + + p.signRequest(req, []byte(reqBodyStr), now) + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("secrets: ListSecrets request failed: %w", err) + } + defer resp.Body.Close() //nolint:gocritic + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("secrets: failed to read ListSecrets response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w: ListSecrets returned status %d: %s", + ErrUnsupported, resp.StatusCode, string(body)) + } + + var result awsListSecretsResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("secrets: failed to parse ListSecrets response: %w", err) + } + + for _, s := range result.SecretList { + if s.Name != "" { + names = append(names, s.Name) + } + } + + if result.NextToken == "" { + break + } + nextToken = result.NextToken + } + + return names, nil } // Config returns the provider's AWS configuration.