diff --git a/internal/cli/start.go b/internal/cli/start.go index 7b789f7..bf06665 100644 --- a/internal/cli/start.go +++ b/internal/cli/start.go @@ -25,11 +25,16 @@ func (c *CLI) newStartCommand() *cobra.Command { log.Fatal().Msgf("%v", err) } - if len(templateID) == 0 { - if err := cmd.Help(); err != nil { - log.Fatal().Msgf("%v", err) - } - return + tagsStr, err := cmd.Flags().GetString("tags") + if err != nil { + log.Fatal().Msgf("%v", err) + } + + if templateID == "" && tagsStr == "" { + log.Fatal().Msg("either --id or --tags must be provided") + } + if templateID != "" && tagsStr != "" { + log.Fatal().Msg("--id and --tags are mutually exclusive") } provider, ok := c.app.GetProvider(providerName) @@ -37,24 +42,44 @@ func (c *CLI) newStartCommand() *cobra.Command { log.Fatal().Msgf("provider %s not found", providerName) } - template, err := tmpl.GetByID(c.app.Templates, templateID) - if err != nil { - log.Fatal().Msgf("%v", err) - } + if templateID != "" { + template, err := tmpl.GetByID(c.app.Templates, templateID) + if err != nil { + log.Fatal().Msgf("%v", err) + } - err = provider.Start(template) - if err != nil { - log.Fatal().Msgf("%v", err) - } + c.startTemplate(provider, template, providerName) + } else { + tags := strings.Split(tagsStr, ",") + templates, err := tmpl.GetByTags(c.app.Templates, tags) + if err != nil { + log.Fatal().Msgf("%v", err) + } + + log.Info().Msgf("found %d templates matching tags: %s", len(templates), tagsStr) + + var failed []string + for _, template := range templates { + if err := provider.Start(template); err != nil { + log.Error().Msgf("failed to start %s: %v", template.ID, err) + failed = append(failed, template.ID) + continue + } - if len(template.PostInstall) > 0 { - log.Info().Msg("Post-installation instructions:") - for _, instruction := range template.PostInstall { - fmt.Printf(" %s\n", instruction) + if len(template.PostInstall) > 0 { + log.Info().Msgf("Post-installation instructions for %s:", template.ID) + for _, instruction := range template.PostInstall { + fmt.Printf(" %s\n", instruction) + } + } + + log.Info().Msgf("%s template is running on %s", template.ID, providerName) } - } - log.Info().Msgf("%s template is running on %s", templateID, providerName) + if len(failed) > 0 { + log.Warn().Msgf("failed to start %d templates: %s", len(failed), strings.Join(failed, ", ")) + } + } }, } @@ -65,13 +90,25 @@ func (c *CLI) newStartCommand() *cobra.Command { cmd.Flags().String("id", "", "Specify a template ID for targeted vulnerable environment") - if err := cmd.MarkFlagRequired("provider"); err != nil { + cmd.Flags().StringP("tags", "t", "", + "Specify comma-separated tags to start all matching templates (e.g., --tags sqli,xss)") + + return cmd +} + +// startTemplate starts a single template and logs the result. +func (c *CLI) startTemplate(provider interface{ Start(*tmpl.Template) error }, template *tmpl.Template, providerName string) { + err := provider.Start(template) + if err != nil { log.Fatal().Msgf("%v", err) } - if err := cmd.MarkFlagRequired("id"); err != nil { - log.Fatal().Msgf("%v", err) + if len(template.PostInstall) > 0 { + log.Info().Msg("Post-installation instructions:") + for _, instruction := range template.PostInstall { + fmt.Printf(" %s\n", instruction) + } } - return cmd + log.Info().Msgf("%s template is running on %s", template.ID, providerName) } diff --git a/internal/cli/stop.go b/internal/cli/stop.go index 1f3ddf3..4362608 100644 --- a/internal/cli/stop.go +++ b/internal/cli/stop.go @@ -13,7 +13,7 @@ import ( func (c *CLI) newStopCommand() *cobra.Command { cmd := &cobra.Command{ Use: "stop", - Short: "Stop vulnerable environment by template id and provider", + Short: "Stop vulnerable environment by template id or tags", Run: func(cmd *cobra.Command, _ []string) { providerName, err := cmd.Flags().GetString("provider") if err != nil { @@ -25,22 +25,63 @@ func (c *CLI) newStopCommand() *cobra.Command { log.Fatal().Msgf("%v", err) } + tagsStr, err := cmd.Flags().GetString("tags") + if err != nil { + log.Fatal().Msgf("%v", err) + } + + if templateID == "" && tagsStr == "" { + log.Fatal().Msg("either --id or --tags must be provided") + } + if templateID != "" && tagsStr != "" { + log.Fatal().Msg("--id and --tags are mutually exclusive") + } + provider, ok := c.app.GetProvider(providerName) if !ok { log.Fatal().Msgf("provider %s not found", providerName) } - template, err := tmpl.GetByID(c.app.Templates, templateID) - if err != nil { - log.Fatal().Msgf("%v", err) - } + if templateID != "" { + template, err := tmpl.GetByID(c.app.Templates, templateID) + if err != nil { + log.Fatal().Msgf("%v", err) + } - err = provider.Stop(template) - if err != nil { - log.Fatal().Msgf("%v", err) - } + err = provider.Stop(template) + if err != nil { + log.Fatal().Msgf("%v", err) + } + + log.Info().Msgf("%s template stopped on %s", templateID, providerName) + } else { + tags := strings.Split(tagsStr, ",") + templates, err := tmpl.GetByTags(c.app.Templates, tags) + if err != nil { + log.Fatal().Msgf("%v", err) + } + + log.Info().Msgf("found %d templates matching tags: %s", len(templates), tagsStr) - log.Info().Msgf("%s template stopped on %s", templateID, providerName) + var failed []string + var stopped int + for _, template := range templates { + if err := provider.Stop(template); err != nil { + log.Error().Msgf("failed to stop %s: %v", template.ID, err) + failed = append(failed, template.ID) + continue + } + stopped++ + log.Info().Msgf("%s template stopped on %s", template.ID, providerName) + } + + if len(failed) > 0 { + log.Warn().Msgf("failed to stop %d templates: %s", len(failed), strings.Join(failed, ", ")) + } + if stopped > 0 { + log.Info().Msgf("successfully stopped %d templates", stopped) + } + } }, } @@ -51,13 +92,8 @@ func (c *CLI) newStopCommand() *cobra.Command { cmd.Flags().String("id", "", "Specify a template ID for targeted vulnerable environment") - if err := cmd.MarkFlagRequired("provider"); err != nil { - log.Fatal().Msgf("%v", err) - } - - if err := cmd.MarkFlagRequired("id"); err != nil { - log.Fatal().Msgf("%v", err) - } + cmd.Flags().StringP("tags", "t", "", + "Specify comma-separated tags to stop all matching templates (e.g., --tags sqli,xss)") return cmd } diff --git a/pkg/template/template.go b/pkg/template/template.go index 7116449..2a3f758 100644 --- a/pkg/template/template.go +++ b/pkg/template/template.go @@ -302,6 +302,46 @@ func GetByID(templates map[string]Template, templateID string) (*Template, error return &tmpl, nil } +// GetByTags retrieves all templates that match any of the given tags. +// Tags are matched case-insensitively and support substring matching. +// Returns an error if no templates match the given tags. +func GetByTags(templates map[string]Template, tags []string) ([]*Template, error) { + if len(tags) == 0 { + return nil, fmt.Errorf("no tags provided") + } + + var matched []*Template + for _, tmpl := range templates { + if templateMatchesTags(&tmpl, tags) { + t := tmpl // Create a copy to avoid pointer issues + matched = append(matched, &t) + } + } + + if len(matched) == 0 { + return nil, fmt.Errorf("no templates found matching tags: %s", strings.Join(tags, ", ")) + } + + return matched, nil +} + +// templateMatchesTags checks if a template matches any of the given tags. +func templateMatchesTags(tmpl *Template, filterTags []string) bool { + for _, filterTag := range filterTags { + filterTag = strings.TrimSpace(filterTag) + if filterTag == "" { + continue + } + for _, templateTag := range tmpl.Info.Tags { + if strings.EqualFold(templateTag, filterTag) || + strings.Contains(strings.ToLower(templateTag), strings.ToLower(filterTag)) { + return true + } + } + } + return false +} + // GetDockerComposePath finds and returns the docker-compose file path for a given template ID. // It searches through all category directories in the templates repository to locate the template. // Returns the absolute path to the compose file and the working directory. diff --git a/pkg/template/template_test.go b/pkg/template/template_test.go index e224b3a..f6a7071 100644 --- a/pkg/template/template_test.go +++ b/pkg/template/template_test.go @@ -33,6 +33,102 @@ func TestGetByID(t *testing.T) { assert.EqualError(t, err, fmt.Sprintf("template %s not found", noneExistTemplateID)) } +func TestGetByTags(t *testing.T) { + templates := map[string]Template{ + "sqli-template": { + ID: "sqli-template", + Info: Info{ + Name: "SQL Injection Lab", + Tags: []string{"sqli", "web", "owasp"}, + }, + }, + "xss-template": { + ID: "xss-template", + Info: Info{ + Name: "XSS Lab", + Tags: []string{"xss", "web", "owasp"}, + }, + }, + "ssrf-template": { + ID: "ssrf-template", + Info: Info{ + Name: "SSRF Lab", + Tags: []string{"ssrf", "web"}, + }, + }, + } + + // Test single tag match + matched, err := GetByTags(templates, []string{"sqli"}) + assert.NoError(t, err) + assert.Len(t, matched, 1) + assert.Equal(t, "sqli-template", matched[0].ID) + + // Test multiple tags (OR logic) + matched, err = GetByTags(templates, []string{"sqli", "xss"}) + assert.NoError(t, err) + assert.Len(t, matched, 2) + + // Test tag that matches multiple templates + matched, err = GetByTags(templates, []string{"web"}) + assert.NoError(t, err) + assert.Len(t, matched, 3) + + // Test case-insensitive matching + matched, err = GetByTags(templates, []string{"SQLI"}) + assert.NoError(t, err) + assert.Len(t, matched, 1) + + // Test substring matching + matched, err = GetByTags(templates, []string{"owa"}) + assert.NoError(t, err) + assert.Len(t, matched, 2) // sqli-template and xss-template have "owasp" + + // Test no matches + matched, err = GetByTags(templates, []string{"nonexistent"}) + assert.Error(t, err) + assert.Nil(t, matched) + assert.Contains(t, err.Error(), "no templates found matching tags") + + // Test empty tags + matched, err = GetByTags(templates, []string{}) + assert.Error(t, err) + assert.Nil(t, matched) + assert.Contains(t, err.Error(), "no tags provided") + + // Test whitespace-only tags are ignored + matched, err = GetByTags(templates, []string{" ", "sqli"}) + assert.NoError(t, err) + assert.Len(t, matched, 1) +} + +func TestTemplateMatchesTags(t *testing.T) { + tmpl := &Template{ + ID: "test-template", + Info: Info{ + Tags: []string{"sqli", "XSS", "OWASP-Top10"}, + }, + } + + // Exact match + assert.True(t, templateMatchesTags(tmpl, []string{"sqli"})) + + // Case-insensitive match + assert.True(t, templateMatchesTags(tmpl, []string{"SQLI"})) + assert.True(t, templateMatchesTags(tmpl, []string{"xss"})) + + // Substring match + assert.True(t, templateMatchesTags(tmpl, []string{"owasp"})) + assert.True(t, templateMatchesTags(tmpl, []string{"top10"})) + + // No match + assert.False(t, templateMatchesTags(tmpl, []string{"nonexistent"})) + + // Empty filter tags + assert.False(t, templateMatchesTags(tmpl, []string{})) + assert.False(t, templateMatchesTags(tmpl, []string{" ", ""})) +} + // createTestTemplate creates a template directory with an index.yaml file func createTestTemplate(t *testing.T, basePath, templateID string) { t.Helper()