From a66ee4ca42289853109b61707b9d006af23880df Mon Sep 17 00:00:00 2001 From: Jonathan Davies Date: Mon, 15 Sep 2025 17:42:22 +0100 Subject: [PATCH 1/2] fix: cleanup --- main.go | 73 ++++++++++++++++++++++++--------------------------------- 1 file changed, 30 insertions(+), 43 deletions(-) diff --git a/main.go b/main.go index a7743d5..9c7e4ae 100644 --- a/main.go +++ b/main.go @@ -26,7 +26,7 @@ import ( type AWSBudgetPlugin struct { Logger hclog.Logger - config *PluginConfig + config *PluginConfig awsBudgetClient *budgets.Client } @@ -34,14 +34,12 @@ type Validator interface { Validate() error } - type PluginConfig struct { - AccountId string `mapstructure:"account_id"` - AwsAccessKeyId string `mapstructure:"aws_access_key_id"` + AccountId string `mapstructure:"account_id"` + AwsAccessKeyId string `mapstructure:"aws_access_key_id"` AwsSecretAccessKey string `mapstructure:"aws_secret_access_key"` - AwsSessionToken string `mapstructure:"aws_session_token"` - AssumeRoleArn string `mapstructure:"assume_role_arn"` - + AwsSessionToken string `mapstructure:"aws_session_token"` + AssumeRoleArn string `mapstructure:"assume_role_arn"` } func (c *PluginConfig) Validate() error { @@ -55,14 +53,14 @@ func (c *PluginConfig) Validate() error { func loadAWSConfig(ctx context.Context, pluginConfig *PluginConfig) (*aws.Config, error) { var awsConfig aws.Config var err error - + if pluginConfig.AwsAccessKeyId != "" && pluginConfig.AwsSecretAccessKey != "" && pluginConfig.AwsSessionToken != "" { // Use credentials if in config creds := aws.NewCredentialsCache( - credentials.NewStaticCredentialsProvider( - pluginConfig.AwsAccessKeyId, - pluginConfig.AwsSecretAccessKey, - pluginConfig.AwsSessionToken, + credentials.NewStaticCredentialsProvider( + pluginConfig.AwsAccessKeyId, + pluginConfig.AwsSecretAccessKey, + pluginConfig.AwsSessionToken, ), ) awsConfig, err = config.LoadDefaultConfig(ctx, config.WithRegion(os.Getenv("AWS_REGION")), config.WithCredentialsProvider(creds)) @@ -98,7 +96,6 @@ func loadAWSConfig(ctx context.Context, pluginConfig *PluginConfig) (*aws.Config return &awsConfig, nil } - func (l *AWSBudgetPlugin) Configure(req *proto.ConfigureRequest) (*proto.ConfigureResponse, error) { l.Logger.Info("Configuring AWS Budget Plugin") pluginConfig := &PluginConfig{} @@ -158,20 +155,16 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH break } - alertCount := 0 - - for _, err := range getNotificationsForBudget(ctx, l.awsBudgetClient, &l.config.AccountId, budget.BudgetName) { - if err != nil { - l.Logger.Error("unable to get notification", "error", err) - evalStatus = proto.ExecutionStatus_FAILURE - accumulatedErrors = errors.Join(accumulatedErrors, err) - break - } - alertCount += 1 + alerts, err := getNotificationsForBudget(ctx, l.awsBudgetClient, &l.config.AccountId, budget.BudgetName) + if err != nil { + l.Logger.Error("unable to get notifications", "error", err) + evalStatus = proto.ExecutionStatus_FAILURE + accumulatedErrors = errors.Join(accumulatedErrors, err) + break } labels := map[string]string{ - "provider": "aws", + "provider": "aws", "type": "budget", "account-id": l.config.AccountId, "budget-name": aws.ToString(budget.BudgetName), @@ -225,11 +218,11 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH Value: aws.ToString(budget.BillingViewArn), }, { - Name: "alert-count", - Value: fmt.Sprintf("%v", alertCount), + Name: "alert-count", + Value: fmt.Sprintf("%v", len(*alerts)), }, { - Name: "health-status", + Name: "health-status", Value: aws.ToString((*string)(&budget.HealthStatus.Status)), }, }, @@ -256,7 +249,9 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH b, _ := json.Marshal(budget) var budgetMap map[string]interface{} _ = json.Unmarshal(b, &budgetMap) - budgetMap["AlertCount"] = alertCount + budgetMap["Alerts"] = alerts + + l.Logger.Info(fmt.Sprintf("Alerts: %v", alerts)) for _, policyPath := range request.GetPolicyPaths() { @@ -276,6 +271,7 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH activities, ) evidence, err := processor.GenerateResults(ctx, policyPath, budgetMap) + l.Logger.Info(fmt.Sprintf("Evidence: %v", evidence)) evidences = slices.Concat(evidences, evidence) if err != nil { accumulatedErrors = errors.Join(accumulatedErrors, err) @@ -290,7 +286,6 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH continue } - } return &proto.EvalResponse{ @@ -314,21 +309,13 @@ func getBudgets(ctx context.Context, client *budgets.Client, accountId *string) } } -func getNotificationsForBudget(ctx context.Context, client *budgets.Client, accountId *string, budgetName *string) iter.Seq2[types.Notification, error] { - return func(yield func(types.Notification, error) bool) { - result, err := client.DescribeNotificationsForBudget(ctx, &budgets.DescribeNotificationsForBudgetInput{AccountId: accountId, BudgetName: budgetName}) - if err != nil { - yield(types.Notification{}, err) - return - } - - for _, notification := range result.Notifications { - if !yield(notification, nil) { - return - } - } +func getNotificationsForBudget(ctx context.Context, client *budgets.Client, accountId *string, budgetName *string) (*[]types.Notification, error) { + result, err := client.DescribeNotificationsForBudget(ctx, &budgets.DescribeNotificationsForBudgetInput{AccountId: accountId, BudgetName: budgetName}) + if err != nil { + return nil, err } + return &result.Notifications, nil } func main() { @@ -352,4 +339,4 @@ func main() { }, GRPCServer: goplugin.DefaultGRPCServer, }) -} \ No newline at end of file +} From 650a064d034b8b47d900bddcf5a9a9ed7f0626c5 Mon Sep 17 00:00:00 2001 From: Jonathan Davies Date: Mon, 15 Sep 2025 17:52:25 +0100 Subject: [PATCH 2/2] fix: use struct for data --- main.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/main.go b/main.go index 9c7e4ae..15403bc 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "encoding/json" "errors" "fmt" "iter" @@ -42,6 +41,11 @@ type PluginConfig struct { AssumeRoleArn string `mapstructure:"assume_role_arn"` } +type SaturatedBudget struct { + Budget *types.Budget + Alerts *[]types.Notification +} + func (c *PluginConfig) Validate() error { if c.AccountId == "" { return fmt.Errorf("account_id is required") @@ -169,7 +173,6 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH "account-id": l.config.AccountId, "budget-name": aws.ToString(budget.BudgetName), } - actors := []*proto.OriginActor{ { Title: "The Continuous Compliance Framework", @@ -246,12 +249,10 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH evidences := make([]*proto.Evidence, 0) - b, _ := json.Marshal(budget) - var budgetMap map[string]interface{} - _ = json.Unmarshal(b, &budgetMap) - budgetMap["Alerts"] = alerts - - l.Logger.Info(fmt.Sprintf("Alerts: %v", alerts)) + data := &SaturatedBudget{ + Budget: &budget, + Alerts: alerts, + } for _, policyPath := range request.GetPolicyPaths() { @@ -270,7 +271,7 @@ func (l *AWSBudgetPlugin) Eval(request *proto.EvalRequest, apiHelper runner.ApiH actors, activities, ) - evidence, err := processor.GenerateResults(ctx, policyPath, budgetMap) + evidence, err := processor.GenerateResults(ctx, policyPath, data) l.Logger.Info(fmt.Sprintf("Evidence: %v", evidence)) evidences = slices.Concat(evidences, evidence) if err != nil {