Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ LOG_DIR=logs
LOG_FILE_PREFIX=app
LOG_MAX_BODY_BYTES=1048576

# Bootstrap
BOOTSTRAP_ADMIN_TEAM=true
BOOTSTRAP_ADMIN_USER=true
BOOTSTRAP_ADMIN_USERNAME=admin
BOOTSTRAP_ADMIN_EMAIL=
BOOTSTRAP_ADMIN_PASSWORD=

# S3 Challenge Files
S3_ENABLED=false
S3_REGION=ap-northeast-2
Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ LOG_DIR=logs
LOG_FILE_PREFIX=app
LOG_MAX_BODY_BYTES=1048576

# Bootstrap
BOOTSTRAP_ADMIN_TEAM=true
BOOTSTRAP_ADMIN_USER=true
BOOTSTRAP_ADMIN_USERNAME=admin
BOOTSTRAP_ADMIN_EMAIL=
BOOTSTRAP_ADMIN_PASSWORD=

# S3 Challenge Files
S3_ENABLED=false
S3_REGION=ap-northeast-2
Expand Down
3 changes: 3 additions & 0 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"syscall"
"time"

"smctf/internal/bootstrap"
"smctf/internal/cache"
"smctf/internal/config"
"smctf/internal/db"
Expand Down Expand Up @@ -102,6 +103,8 @@ func main() {
stackClient := stack.NewClient(cfg.Stack.ProvisionerBaseURL, cfg.Stack.ProvisionerAPIKey, cfg.Stack.ProvisionerTimeout)
stackSvc := service.NewStackService(cfg.Stack, stackRepo, challengeRepo, submissionRepo, stackClient, redisClient)

bootstrap.BootstrapAdmin(ctx, cfg, database, userRepo, teamRepo, logger)

if cfg, _, _, err := appConfigSvc.Get(ctx); err != nil {
logger.Warn("app config load warning", slog.Any("error", err))
} else if cfg.CTFStartAt == "" && cfg.CTFEndAt == "" {
Expand Down
137 changes: 137 additions & 0 deletions internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package bootstrap

import (
"context"
"fmt"
"log/slog"
"strings"
"time"

"smctf/internal/auth"
"smctf/internal/config"
"smctf/internal/db"
"smctf/internal/logging"
"smctf/internal/models"
"smctf/internal/repo"

"github.com/uptrace/bun"
)

const (
bootstrapAdminTeamName = "Admin"
)

func BootstrapAdmin(ctx context.Context, cfg config.Config, database *bun.DB, userRepo *repo.UserRepo, teamRepo *repo.TeamRepo, logger *logging.Logger) {
if !cfg.Bootstrap.AdminTeamEnabled && !cfg.Bootstrap.AdminUserEnabled {
return
}

empty, err := isDatabaseEmpty(ctx, database)
if err != nil {
logger.Error("bootstrap database check error", slog.Any("error", err))
return
}

if !empty {
logger.Info("bootstrap skipped: database is not empty")
return
}

var team *models.Team

if cfg.Bootstrap.AdminTeamEnabled {
team, err = ensureAdminTeam(ctx, teamRepo)
if err != nil {
logger.Error("bootstrap admin team error", slog.Any("error", err))
return
}

if team != nil {
logger.Info("admin team created", slog.Any("team_id", team.ID), slog.Any("team_name", team.Name))
}
}

if team != nil && cfg.Bootstrap.AdminUserEnabled {
user, err := ensureAdminUser(ctx, cfg, team, userRepo)
if err != nil {
logger.Error("bootstrap admin user error", slog.Any("error", err))
return
}

if user != nil {
logger.Info("admin user created", slog.Any("user_id", user.ID))
}
}
}

func ensureAdminTeam(ctx context.Context, teamRepo *repo.TeamRepo) (*models.Team, error) {
team := &models.Team{
Name: bootstrapAdminTeamName,
CreatedAt: time.Now().UTC(),
}

if err := teamRepo.Create(ctx, team); err != nil {
if db.IsUniqueViolation(err) {
return nil, nil
}

return nil, fmt.Errorf("create team: %w", err)
}

return team, nil
}

func ensureAdminUser(ctx context.Context, cfg config.Config, team *models.Team, userRepo *repo.UserRepo) (*models.User, error) {
email := strings.TrimSpace(cfg.Bootstrap.AdminEmail)
password := strings.TrimSpace(cfg.Bootstrap.AdminPassword)
if email == "" || password == "" {
return nil, nil
}

username := strings.TrimSpace(cfg.Bootstrap.AdminUsername)
if username == "" {
username = "admin"
}

hash, err := auth.HashPassword(password, cfg.PasswordBcryptCost)
if err != nil {
return nil, fmt.Errorf("hash admin password: %w", err)
}

now := time.Now().UTC()
user := &models.User{
Email: email,
Username: username,
PasswordHash: hash,
Role: models.AdminRole,
TeamID: team.ID,
CreatedAt: now,
UpdatedAt: now,
}

if err := userRepo.Create(ctx, user); err != nil {
if db.IsUniqueViolation(err) {
return nil, nil
}

return nil, fmt.Errorf("create admin user: %w", err)
}

return user, nil
}

func isDatabaseEmpty(ctx context.Context, database *bun.DB) (bool, error) {
tables := []string{"users", "teams", "registration_keys"}
for _, table := range tables {
count, err := database.NewSelect().TableExpr(table).Count(ctx)
if err != nil {
return false, fmt.Errorf("count %s: %w", table, err)
}

if count > 0 {
return false, nil
}
}

return true, nil
}
Loading