diff --git a/main.go b/main.go index a7743d5..15403bc 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "encoding/json" "errors" "fmt" "iter" @@ -26,7 +25,7 @@ import ( type AWSBudgetPlugin struct { Logger hclog.Logger - config *PluginConfig + config *PluginConfig awsBudgetClient *budgets.Client } @@ -34,14 +33,17 @@ type Validator interface { Validate() error } - type PluginConfig struct { - AccountId string `mapstructure:"account_id"` - AwsAccessKeyId string `mapstructure:"aws_access_key_id"` + AccountId string `mapstructure:"account_id"` + AwsAccessKeyId string `mapstructure:"aws_access_key_id"` AwsSecretAccessKey string `mapstructure:"aws_secret_access_key"` - AwsSessionToken string `mapstructure:"aws_session_token"` - AssumeRoleArn string `mapstructure:"assume_role_arn"` + AwsSessionToken string `mapstructure:"aws_session_token"` + AssumeRoleArn string `mapstructure:"assume_role_arn"` +} +type SaturatedBudget struct { + Budget *types.Budget + Alerts *[]types.Notification } func (c *PluginConfig) Validate() error { @@ -55,14 +57,14 @@ func (c *PluginConfig) Validate() error { func loadAWSConfig(ctx context.Context, pluginConfig *PluginConfig) (*aws.Config, error) { var awsConfig aws.Config var err error - + if pluginConfig.AwsAccessKeyId != "" && pluginConfig.AwsSecretAccessKey != "" && pluginConfig.AwsSessionToken != "" { // Use credentials if in config creds := aws.NewCredentialsCache( - credentials.NewStaticCredentialsProvider( - pluginConfig.AwsAccessKeyId, - pluginConfig.AwsSecretAccessKey, - pluginConfig.AwsSessionToken, + credentials.NewStaticCredentialsProvider( + pluginConfig.AwsAccessKeyId, + pluginConfig.AwsSecretAccessKey, + pluginConfig.AwsSessionToken, ), ) awsConfig, err = config.LoadDefaultConfig(ctx, config.WithRegion(os.Getenv("AWS_REGION")), config.WithCredentialsProvider(creds)) @@ -98,7 +100,6 @@ func loadAWSConfig(ctx context.Context, pluginConfig *PluginConfig) (*aws.Config return &awsConfig, nil } - func (l *AWSBudgetPlugin) Configure(req *proto.ConfigureRequest) (*proto.ConfigureResponse, error) { l.Logger.Info("Configuring AWS Budget Plugin") pluginConfig := &PluginConfig{} @@ -158,25 +159,20 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH break } - alertCount := 0 - - for _, err := range getNotificationsForBudget(ctx, l.awsBudgetClient, &l.config.AccountId, budget.BudgetName) { - if err != nil { - l.Logger.Error("unable to get notification", "error", err) - evalStatus = proto.ExecutionStatus_FAILURE - accumulatedErrors = errors.Join(accumulatedErrors, err) - break - } - alertCount += 1 + alerts, err := getNotificationsForBudget(ctx, l.awsBudgetClient, &l.config.AccountId, budget.BudgetName) + if err != nil { + l.Logger.Error("unable to get notifications", "error", err) + evalStatus = proto.ExecutionStatus_FAILURE + accumulatedErrors = errors.Join(accumulatedErrors, err) + break } labels := map[string]string{ - "provider": "aws", + "provider": "aws", "type": "budget", "account-id": l.config.AccountId, "budget-name": aws.ToString(budget.BudgetName), } - actors := []*proto.OriginActor{ { Title: "The Continuous Compliance Framework", @@ -225,11 +221,11 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH Value: aws.ToString(budget.BillingViewArn), }, { - Name: "alert-count", - Value: fmt.Sprintf("%v", alertCount), + Name: "alert-count", + Value: fmt.Sprintf("%v", len(*alerts)), }, { - Name: "health-status", + Name: "health-status", Value: aws.ToString((*string)(&budget.HealthStatus.Status)), }, }, @@ -253,10 +249,10 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH evidences := make([]*proto.Evidence, 0) - b, _ := json.Marshal(budget) - var budgetMap map[string]interface{} - _ = json.Unmarshal(b, &budgetMap) - budgetMap["AlertCount"] = alertCount + data := &SaturatedBudget{ + Budget: &budget, + Alerts: alerts, + } for _, policyPath := range request.GetPolicyPaths() { @@ -275,7 +271,8 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH actors, activities, ) - evidence, err := processor.GenerateResults(ctx, policyPath, budgetMap) + evidence, err := processor.GenerateResults(ctx, policyPath, data) + l.Logger.Info(fmt.Sprintf("Evidence: %v", evidence)) evidences = slices.Concat(evidences, evidence) if err != nil { accumulatedErrors = errors.Join(accumulatedErrors, err) @@ -290,7 +287,6 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH continue } - } return &proto.EvalResponse{ @@ -314,21 +310,13 @@ func getBudgets(ctx context.Context, client *budgets.Client, accountId *string) } } -func getNotificationsForBudget(ctx context.Context, client *budgets.Client, accountId *string, budgetName *string) iter.Seq2[types.Notification, error] { - return func(yield func(types.Notification, error) bool) { - result, err := client.DescribeNotificationsForBudget(ctx, &budgets.DescribeNotificationsForBudgetInput{AccountId: accountId, BudgetName: budgetName}) - if err != nil { - yield(types.Notification{}, err) - return - } - - for _, notification := range result.Notifications { - if !yield(notification, nil) { - return - } - } +func getNotificationsForBudget(ctx context.Context, client *budgets.Client, accountId *string, budgetName *string) (*[]types.Notification, error) { + result, err := client.DescribeNotificationsForBudget(ctx, &budgets.DescribeNotificationsForBudgetInput{AccountId: accountId, BudgetName: budgetName}) + if err != nil { + return nil, err } + return &result.Notifications, nil } func main() { @@ -352,4 +340,4 @@ func main() { }, GRPCServer: goplugin.DefaultGRPCServer, }) -} \ No newline at end of file +}