diff --git a/backend/internal/activities/deployer_activities.go b/backend/internal/activities/deployer_activities.go index 7b5ead6b..8dabac38 100644 --- a/backend/internal/activities/deployer_activities.go +++ b/backend/internal/activities/deployer_activities.go @@ -711,35 +711,6 @@ func registerDeploymentActivities(engine *ewf.Engine, metrics *metrics.Metrics, engine.RegisterTemplate(constants.WorkflowRollbackFailedAddNode, &rollbackAddNodeWFTemplate) } -func getFromState[T any](state ewf.State, key string) (T, error) { - value, ok := state[key] - if !ok { - var zero T - return zero, fmt.Errorf("missing '%s' in state", key) - } - - // Try direct type assertion first (for newly created values) - if val, ok := value.(T); ok { - return val, nil - } - - // Handle the case where value was serialized/deserialized and became a map - // Use JSON marshaling/unmarshaling to convert map to struct - valueBytes, err := json.Marshal(value) - if err != nil { - var zero T - return zero, fmt.Errorf("failed to marshal %s value: %w", key, err) - } - - var result T - if err := json.Unmarshal(valueBytes, &result); err != nil { - var zero T - return zero, fmt.Errorf("failed to unmarshal %s: %w", key, err) - } - - return result, nil -} - func getConfig(state ewf.State) (statemanager.ClientConfig, error) { value, ok := state["config"] if !ok { diff --git a/backend/internal/activities/node_activities.go b/backend/internal/activities/node_activities.go index bf832306..39793b1f 100644 --- a/backend/internal/activities/node_activities.go +++ b/backend/internal/activities/node_activities.go @@ -14,15 +14,11 @@ import ( func CreateIdentityStep() ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - - mnemonicVal, ok := state["mnemonic"] - if !ok { - return fmt.Errorf("missing 'mnemonic' in state") - } - mnemonic, ok := mnemonicVal.(string) - if !ok { - return fmt.Errorf("'mnemonic' in state is not a string") + mnemonic, err := getFromState[string](state, "mnemonic") + if err != nil { + return err } + identity, err := substrate.NewIdentityFromSr25519Phrase(mnemonic) if err != nil { return fmt.Errorf("failed to create identity: %w", err) @@ -34,17 +30,19 @@ func CreateIdentityStep() ewf.StepFn { func ReserveNodeStep(db models.DB, substrateClient *substrate.Substrate) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - userID, ok := state["user_id"].(int) - if !ok { - return fmt.Errorf("missing or invalid 'user_id' in state") + userID, err := getFromState[int](state, "user_id") + if err != nil { + return err } - nodeID, ok := state["node_id"].(uint32) - if !ok { - return fmt.Errorf("missing or invalid 'node_id' in state") + + nodeID, err := getFromState[uint32](state, "node_id") + if err != nil { + return err } - identity, ok := state["identity"].(substrate.Identity) - if !ok { - return fmt.Errorf("missing or invalid 'identity' in state") + + identity, err := getFromState[substrate.Identity](state, "identity") + if err != nil { + return err } // Reserve the node @@ -70,13 +68,14 @@ func ReserveNodeStep(db models.DB, substrateClient *substrate.Substrate) ewf.Ste func UnreserveNodeStep(db models.DB, substrateClient *substrate.Substrate) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - contractID, ok := state["contract_id"].(uint64) - if !ok { - return fmt.Errorf("missing or invalid 'contract_id' in state") + contractID, err := getFromState[uint64](state, "contract_id") + if err != nil { + return err } - mnemonic, ok := state["mnemonic"].(string) - if !ok { - return fmt.Errorf("missing or invalid 'mnemonic' in state") + + mnemonic, err := getFromState[string](state, "mnemonic") + if err != nil { + return err } identity, err := substrate.NewIdentityFromSr25519Phrase(mnemonic) @@ -101,23 +100,17 @@ func UnreserveNodeStep(db models.DB, substrateClient *substrate.Substrate) ewf.S // VerifyNodeStateStep checks if node has reached the desired state func VerifyNodeStateStep(proxyClient proxy.Client) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - - targetStatus, ok := state["target_status"].(string) - if !ok { - return fmt.Errorf("missing or invalid 'target_status' in state") - } - - nodeID, exists := state["node_id"] - if !exists { - return fmt.Errorf("missing or invalid 'node_id' in state") + targetStatus, err := getFromState[string](state, "target_status") + if err != nil { + return err } - nodeIDUint32, ok := nodeID.(uint32) - if !ok { - return fmt.Errorf("node_id in state is not a uint32") + nodeID, err := getFromState[uint32](state, "node_id") + if err != nil { + return err } - node, err := proxyClient.Node(ctx, nodeIDUint32) + node, err := proxyClient.Node(ctx, nodeID) if err != nil { return fmt.Errorf("failed to get node: %w", err) } @@ -125,7 +118,7 @@ func VerifyNodeStateStep(proxyClient proxy.Client) ewf.StepFn { reached := targetStatus == constants.NodeRentable && node.Rentable || targetStatus == constants.NodeRented && !node.Rentable if !reached { - return fmt.Errorf("node %d has not reached target status '%s' (current: rentable=%v)", nodeIDUint32, targetStatus, node.Rentable) + return fmt.Errorf("node %d has not reached target status '%s' (current: rentable=%v)", nodeID, targetStatus, node.Rentable) } return nil diff --git a/backend/internal/activities/notification_activities.go b/backend/internal/activities/notification_activities.go index f8a321eb..8e1d2caf 100644 --- a/backend/internal/activities/notification_activities.go +++ b/backend/internal/activities/notification_activities.go @@ -17,18 +17,22 @@ func SendNotification(db models.DB, notifier notification.Notifier) ewf.StepFn { if !ok { return fmt.Errorf("missing notification in workflow state") } + notif, ok := raw.(*models.Notification) if !ok || notif == nil { return fmt.Errorf("invalid notification in workflow state") } + if !slices.Contains(notif.Channels, notifier.GetType()) { logger.GetLogger().Debug().Msgf("SendNotification: step skipped for channel %s (not in notification channels)", notifier.GetType()) return nil } + user, err := db.GetUserByID(notif.UserID) if err != nil { return fmt.Errorf("failed to get user by ID (id: %v): %w", notif.UserID, err) } + if err := notifier.Notify(*notif, user.Email); err != nil { return fmt.Errorf("failed to send notification (id: %v) to %s: %w", notif.ID, notifier.GetType(), err) } diff --git a/backend/internal/activities/state_helpers.go b/backend/internal/activities/state_helpers.go new file mode 100644 index 00000000..9f8839d1 --- /dev/null +++ b/backend/internal/activities/state_helpers.go @@ -0,0 +1,39 @@ +package activities + +import ( + "encoding/json" + "fmt" + + "github.com/xmonader/ewf" +) + +// getFromState is a generic helper function to extract and type-cast values from workflow state. +// It handles both direct type assertions and JSON-based conversions for serialized/deserialized values. +func getFromState[T any](state ewf.State, key string) (T, error) { + value, ok := state[key] + if !ok { + var zero T + return zero, fmt.Errorf("missing '%s' in state", key) + } + + // Try direct type assertion first (for newly created values) + if val, ok := value.(T); ok { + return val, nil + } + + // Handle the case where value was serialized/deserialized and became a map + // Use JSON marshaling/unmarshaling to convert map to struct + valueBytes, err := json.Marshal(value) + if err != nil { + var zero T + return zero, fmt.Errorf("failed to marshal %s value: %w", key, err) + } + + var result T + if err := json.Unmarshal(valueBytes, &result); err != nil { + var zero T + return zero, fmt.Errorf("failed to unmarshal %s: %w", key, err) + } + + return result, nil +} diff --git a/backend/internal/activities/user_activities.go b/backend/internal/activities/user_activities.go index 62e43c33..aef8a1ef 100644 --- a/backend/internal/activities/user_activities.go +++ b/backend/internal/activities/user_activities.go @@ -19,31 +19,19 @@ import ( func CreateUserStep(config internal.Configuration, db models.DB) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - emailVal, ok := state["email"] - if !ok { - return fmt.Errorf("missing 'email' in state") - } - email, ok := emailVal.(string) - if !ok { - return fmt.Errorf("'email' in state is not a string") + email, err := getFromState[string](state, "email") + if err != nil { + return err } - nameVal, ok := state["name"] - if !ok { - return fmt.Errorf("missing 'name' in state") - } - name, ok := nameVal.(string) - if !ok { - return fmt.Errorf("'name' in state is not a string") + name, err := getFromState[string](state, "name") + if err != nil { + return err } - passwordVal, ok := state["password"] - if !ok { - return fmt.Errorf("missing 'password' in state") - } - password, ok := passwordVal.(string) - if !ok { - return fmt.Errorf("'password' in state is not a string") + password, err := getFromState[string](state, "password") + if err != nil { + return err } hashedPassword, err := internal.HashAndSaltPassword([]byte(password)) @@ -81,22 +69,14 @@ func CreateUserStep(config internal.Configuration, db models.DB) ewf.StepFn { func SendVerificationEmailStep(mailService internal.MailService, config internal.Configuration) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - emailVal, ok := state["email"] - if !ok { - return fmt.Errorf("missing 'email' in state") - } - email, ok := emailVal.(string) - if !ok { - return fmt.Errorf("'email' in state is not a string") + email, err := getFromState[string](state, "email") + if err != nil { + return err } - nameVal, ok := state["name"] - if !ok { - return fmt.Errorf("missing 'name' in state") - } - name, ok := nameVal.(string) - if !ok { - return fmt.Errorf("'name' in state is not a string") + name, err := getFromState[string](state, "name") + if err != nil { + return err } code := internal.GenerateRandomCode() @@ -113,22 +93,14 @@ func SendVerificationEmailStep(mailService internal.MailService, config internal func UpdateCodeStep(db models.DB) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - emailVal, ok := state["email"] - if !ok { - return fmt.Errorf("missing 'email' in state") - } - email, ok := emailVal.(string) - if !ok { - return fmt.Errorf("'email' in state is not a string") + email, err := getFromState[string](state, "email") + if err != nil { + return err } - codeVal, ok := state["code"] - if !ok { - return fmt.Errorf("missing 'code' in state") - } - code, ok := codeVal.(int) - if !ok { - return fmt.Errorf("'code' in state is not a int") + code, err := getFromState[int](state, "code") + if err != nil { + return err } existingUser, err := db.GetUserByEmail(email) @@ -143,13 +115,9 @@ func UpdateCodeStep(db models.DB) ewf.StepFn { func SetupTFChainStep(client *substrate.Substrate, config internal.Configuration, notificationService *notification.NotificationService, db models.DB) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - userIDVal, ok := state["user_id"] - if !ok { - return fmt.Errorf("missing 'user_id' in state") - } - userID, ok := userIDVal.(int) - if !ok { - return fmt.Errorf("'user_id' in state is not an int") + userID, err := getFromState[int](state, "user_id") + if err != nil { + return err } existingUser, err := db.GetUserByID(userID) @@ -181,13 +149,9 @@ func SetupTFChainStep(client *substrate.Substrate, config internal.Configuration func CreateStripeCustomerStep(db models.DB) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - userIDVal, ok := state["user_id"] - if !ok { - return fmt.Errorf("missing 'user_id' in state") - } - userID, ok := userIDVal.(int) - if !ok { - return fmt.Errorf("'user_id' in state is not an int") + userID, err := getFromState[int](state, "user_id") + if err != nil { + return err } existingUser, err := db.GetUserByID(userID) @@ -199,22 +163,14 @@ func CreateStripeCustomerStep(db models.DB) ewf.StepFn { return nil } - emailVal, ok := state["email"] - if !ok { - return fmt.Errorf("missing 'email' in state") - } - email, ok := emailVal.(string) - if !ok { - return fmt.Errorf("'email' in state is not a string") + email, err := getFromState[string](state, "email") + if err != nil { + return err } - nameVal, ok := state["name"] - if !ok { - return fmt.Errorf("missing 'name' in state") - } - name, ok := nameVal.(string) - if !ok { - return fmt.Errorf("'name' in state is not a string") + name, err := getFromState[string](state, "name") + if err != nil { + return err } customer, err := internal.CreateStripeCustomer(name, email) @@ -235,13 +191,9 @@ func CreateStripeCustomerStep(db models.DB) ewf.StepFn { func CreateKYCSponsorship(kycClient *internal.KYCClient, notificationService *notification.NotificationService, sponsorAddress string, sponsorKeyPair subkey.KeyPair, db models.DB) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - userIDVal, ok := state["user_id"] - if !ok { - return fmt.Errorf("missing 'user_id' in state") - } - userID, ok := userIDVal.(int) - if !ok { - return fmt.Errorf("'user_id' in state is not an int") + userID, err := getFromState[int](state, "user_id") + if err != nil { + return err } existingUser, err := db.GetUserByID(userID) @@ -253,13 +205,9 @@ func CreateKYCSponsorship(kycClient *internal.KYCClient, notificationService *no return nil } - mnemonicVal, ok := state["mnemonic"] - if !ok { - return fmt.Errorf("missing 'mnemonic' in state") - } - mnemonic, ok := mnemonicVal.(string) - if !ok { - return fmt.Errorf("'mnemonic' in state is not a string") + mnemonic, err := getFromState[string](state, "mnemonic") + if err != nil { + return err } // Set user.AccountAddress from mnemonic @@ -295,22 +243,14 @@ func SendWelcomeEmailStep(mailService internal.MailService, config internal.Conf return func(ctx context.Context, state ewf.State) error { metrics.IncrementUserRegistration() - emailVal, ok := state["email"] - if !ok { - return fmt.Errorf("missing 'email' in state") - } - email, ok := emailVal.(string) - if !ok { - return fmt.Errorf("'email' in state is not a string") + email, err := getFromState[string](state, "email") + if err != nil { + return err } - nameVal, ok := state["name"] - if !ok { - return fmt.Errorf("missing 'name' in state") - } - name, ok := nameVal.(string) - if !ok { - return fmt.Errorf("'name' in state is not a string") + name, err := getFromState[string](state, "name") + if err != nil { + return err } subject, body := mailService.WelcomeMailContent(name, config.Server.Host) @@ -323,29 +263,19 @@ func SendWelcomeEmailStep(mailService internal.MailService, config internal.Conf func CreatePaymentIntentStep(currency string, metrics *metrics.Metrics, notificationService *notification.NotificationService) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - customerIDVal, ok := state["stripe_customer_id"] - if !ok { - return fmt.Errorf("missing 'stripe_customer_id' in state") - } - customerID, ok := customerIDVal.(string) - if !ok { - return fmt.Errorf("'stripe_customer_id' in state is not a string") - } - paymentMethodIDVal, ok := state["payment_method_id"] - if !ok { - return fmt.Errorf("missing 'payment_method_id' in state") - } - paymentMethodID, ok := paymentMethodIDVal.(string) - if !ok { - return fmt.Errorf("'payment_method_id' in state is not a string") + customerID, err := getFromState[string](state, "stripe_customer_id") + if err != nil { + return err } - amountVal, ok := state["amount"] - if !ok { - return fmt.Errorf("missing 'amount' in state") + + paymentMethodID, err := getFromState[string](state, "payment_method_id") + if err != nil { + return err } - amount, ok := amountVal.(uint64) - if !ok { - return fmt.Errorf("'amount' in state is not a uint64") + + amount, err := getFromState[uint64](state, "amount") + if err != nil { + return err } intent, err := internal.CreatePaymentIntent(customerID, paymentMethodID, currency, amount) @@ -362,41 +292,24 @@ func CreatePaymentIntentStep(currency string, metrics *metrics.Metrics, notifica func CreatePendingRecord(substrateClient *substrate.Substrate, db models.DB, systemMnemonic string) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - amountVal, ok := state["amount"] - if !ok { - return fmt.Errorf("missing 'amount' in state") - } - - amount, ok := amountVal.(uint64) - if !ok { - return fmt.Errorf("'amount' in state is not a uint64") + amount, err := getFromState[uint64](state, "amount") + if err != nil { + return err } - userIDVal, ok := state["user_id"] - if !ok { - return fmt.Errorf("missing 'user_id' in state") - } - userID, ok := userIDVal.(int) - if !ok { - return fmt.Errorf("'user_id' in state is not an int") + userID, err := getFromState[int](state, "user_id") + if err != nil { + return err } - usernameVal, ok := state["username"] - if !ok { - return fmt.Errorf("missing 'username' in state") - } - username, ok := usernameVal.(string) - if !ok { - return fmt.Errorf("'username' in state is not a string") + username, err := getFromState[string](state, "username") + if err != nil { + return err } - transferModeVal, ok := state["transfer_mode"] - if !ok { - return fmt.Errorf("missing 'transfer_mode' in state") - } - transferMode, ok := transferModeVal.(string) - if !ok { - return fmt.Errorf("'transfer_mode' in state is not a string") + transferMode, err := getFromState[string](state, "transfer_mode") + if err != nil { + return err } requestedTFTs, err := internal.FromUSDMillicentToTFT(substrateClient, amount) @@ -421,22 +334,14 @@ func CreatePendingRecord(substrateClient *substrate.Substrate, db models.DB, sys func UpdateCreditCardBalanceStep(db models.DB) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - userIDVal, ok := state["user_id"] - if !ok { - return fmt.Errorf("missing 'user_id' in state") - } - userID, ok := userIDVal.(int) - if !ok { - return fmt.Errorf("'user_id' in state is not an int") + userID, err := getFromState[int](state, "user_id") + if err != nil { + return err } - amountVal, ok := state["amount"] - if !ok { - return fmt.Errorf("missing 'amount' in state") - } - amount, ok := amountVal.(uint64) - if !ok { - return fmt.Errorf("'amount' in state is not a uint64") + amount, err := getFromState[uint64](state, "amount") + if err != nil { + return err } user, err := db.GetUserByID(userID) @@ -462,22 +367,14 @@ func UpdateCreditCardBalanceStep(db models.DB) ewf.StepFn { func UpdateCreditedBalanceStep(db models.DB) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { - userIDVal, ok := state["user_id"] - if !ok { - return fmt.Errorf("missing 'user_id' in state") - } - userID, ok := userIDVal.(int) - if !ok { - return fmt.Errorf("'user_id' in state is not an int") + userID, err := getFromState[int](state, "user_id") + if err != nil { + return err } - amountVal, ok := state["amount"] - if !ok { - return fmt.Errorf("missing 'amount' in state") - } - amount, ok := amountVal.(uint64) - if !ok { - return fmt.Errorf("'amount' in state is not a uint64") + amount, err := getFromState[uint64](state, "amount") + if err != nil { + return err } user, err := db.GetUserByID(userID)