diff --git a/.gitignore b/.gitignore index f7b81002..a1ddd657 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,14 @@ dist # Dependency directories (remove the comment below to include it) # vendor/ + +# Cross-validation testing (not for upstream) +mining/cross_validation_test.go +mining/clustering_cross_validation_test.go +mining/quality_check_cross_validation_test.go +mining/hierarchical_cross_validation_test.go +mining/test_helpers.go +mining/testData/ +mining/cross_validation/ +/alterx +/cmd/alterx/alterx \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..2f8f8d54 --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +.PHONY: all build test lint clean + +# Default target +all: build + +# Build the alterx binary +build: + @echo "Building alterx..." + @go build -o alterx cmd/alterx/main.go + @echo "Build complete: ./alterx" + +# Run tests +test: + @echo "Running tests..." + @go test -v ./... + +# Run linters +lint: + @echo "Running linters..." + @if command -v golangci-lint >/dev/null 2>&1; then \ + golangci-lint run ./...; \ + else \ + echo "golangci-lint not found, running go vet..."; \ + go vet ./...; \ + fi + +# Clean build artifacts +clean: + @echo "Cleaning build artifacts..." + @rm -f alterx + @echo "Clean complete" diff --git a/PATTERN_MINING.md b/PATTERN_MINING.md new file mode 100644 index 00000000..8d838500 --- /dev/null +++ b/PATTERN_MINING.md @@ -0,0 +1,176 @@ +# Pattern Mining Feature + +## Overview + +Pattern mining allows alterx to automatically discover patterns from a list of input domains, eliminating the need to manually define patterns and payloads. 
+ +## Usage + +### Basic Discover Mode + +```bash +# Discover patterns from a list of domains +alterx -l domains.txt -mode discover + +# Limit the output +alterx -l domains.txt -mode discover -limit 100 +``` + +### Advanced Options + +All pattern mining options are grouped under the **"Pattern Mining"** flag group: + +| Flag | Default | Description | +|------|---------|-------------| +| `-m, -mode` | `default` | Pattern mode: `default`, `discover`, or `both` | +| `-min-distance` | `2` | Minimum levenshtein distance for clustering | +| `-max-distance` | `5` | Maximum levenshtein distance for clustering | +| `-pattern-threshold` | `500` | Pattern threshold for filtering low-quality patterns | +| `-quality-ratio` | `25` | Pattern quality ratio threshold | +| `-ngrams-limit` | `0` | Limit number of n-grams to process (0 = all) | + +#### Pattern Modes + +The `-mode` flag controls which patterns are used: + +- **`default`**: Use user-specified or default patterns only (traditional behavior, used when `-mode` not specified) +- **`discover`**: Use only mined patterns from input domains (no defaults) +- **`both`**: Combine mined patterns with defaults for maximum coverage + +### Examples + +**1. Basic discover mode (mined patterns only):** +```bash +# Use -mode discover to mine patterns (no default patterns) +alterx -l subdomains.txt -mode discover -limit 50 +``` + +**2. Combine mined patterns with defaults:** +```bash +# Use -mode both to get maximum coverage +alterx -l subdomains.txt -mode both -limit 100 +``` + +**3. Explicitly use only default patterns:** +```bash +# Use -mode default for traditional behavior +alterx -l subdomains.txt -mode default -limit 50 +``` + +**4. Custom mining parameters:** +```bash +alterx -l subdomains.txt -mode discover \ + -min-distance 3 \ + -max-distance 6 \ + -pattern-threshold 500 \ + -quality-ratio 80 \ + -limit 100 +``` + +**5. 
Fast mode (limit n-grams):** +```bash +# Process only first 100 n-grams for faster results +alterx -l subdomains.txt -mode discover -ngrams-limit 100 +``` + +**6. Discover and save to file:** +```bash +alterx -l subdomains.txt -mode discover -o permutations.txt +``` + +## Input Requirements + +For optimal pattern discovery: +- **Minimum**: 10 domains (warning shown if fewer) +- **Recommended**: 50+ domains for better pattern diversity +- **Best**: 100+ domains with varied structures + +## How It Works + +The pattern mining algorithm uses two complementary approaches: + +1. **Levenshtein Distance Clustering**: Groups similar subdomains based on edit distance +2. **Hierarchical N-gram Clustering**: Analyzes subdomains at multiple granularity levels + +### Example + +Given input domains: +``` +api-prod.example.com +api-staging.example.com +web-prod.example.com +web-staging.example.com +``` + +Discovered patterns: +``` +api-{{p0}}.{{root}} → payloads: {"p0": ["prod", "staging"]} +web-{{p0}}.{{root}} → payloads: {"p0": ["prod", "staging"]} +{{p0}}.{{root}} → payloads: {"p0": ["api-prod", "api-staging", "web-prod", "web-staging"]} +``` + +Generated permutations: +``` +api-prod.example.com +api-staging.example.com +web-prod.example.com +web-staging.example.com +(and many more combinations...) 
+``` + +## Architecture + +The implementation uses a clean interface-based design: + +- **`PatternProvider`** interface: Common contract for pattern generation strategies +- **`ManualPatternProvider`**: Traditional mode with user-specified patterns +- **`MinedPatternProvider`**: Discover mode with automatic pattern mining +- **Mutator**: Uses patterns/payloads from provider transparently + +## Backward Compatibility + +Manual mode remains unchanged: +```bash +# Traditional usage still works exactly as before +alterx -l domains.txt -p "{{word}}.{{root}}" -pp 'word=words.txt' +``` + +## Performance Tuning + +### For Large Datasets (1000+ domains) + +```bash +# Reduce distance ranges +alterx -l large-list.txt -mode discover -min-distance 2 -max-distance 4 + +# Limit n-grams for faster processing +alterx -l large-list.txt -mode discover -ngrams-limit 200 +``` + +### For Quality over Speed + +```bash +# Process all n-grams with strict thresholds +alterx -l domains.txt -mode discover \ + -ngrams-limit 0 \ + -pattern-threshold 2000 \ + -quality-ratio 150 +``` + +## Testing + +Run pattern mining tests: +```bash +# Unit tests +go test -v -run TestMinedPatternProvider + +# Integration tests +go test -v -run TestMutatorIntegration_DiscoverMode + +# Cross-validation tests (requires Python) +cd mining && go test -v -run TestPatternDifferences +``` + +## Algorithm Details + +See [mining/README.md](mining/README.md) for detailed algorithm documentation and Python reference implementation comparison. 
diff --git a/README.md b/README.md index 69704f56..c870d39c 100644 --- a/README.md +++ b/README.md @@ -230,6 +230,15 @@ $ alterx -list tesla.txt -enrich -p '{{word}}-{{year}}.{{suffix}}' -pp word=keyw **For more information, please checkout the release blog** - https://blog.projectdiscovery.io/introducing-alterx-simplifying-active-subdomain-enumeration-with-patterns/ +## Pattern Mining + +The pattern mining implementation in this project is based on the [regulator](https://github.com/cramppet/regulator) project by [@cramppet](https://github.com/cramppet). Regulator is a subdomain pattern mining tool that uses hierarchical clustering algorithms to automatically discover patterns in subdomain datasets. We've adapted and extended these concepts to provide automatic pattern generation capabilities within AlterX. + +### Attribution + +The hierarchical ngram-based clustering approach and pattern mining algorithms are inspired by and adapted from the [regulator project](https://github.com/cramppet/regulator). Special thanks to [@cramppet](https://github.com/cramppet) for the excellent work on subdomain pattern analysis. 
+ +--- Do also check out the below similar open-source projects that may fit in your workflow: diff --git a/algo.go b/algo.go index 38626e21..ca1660d1 100644 --- a/algo.go +++ b/algo.go @@ -18,6 +18,13 @@ func ClusterBomb(payloads *IndexMap, callback func(varMap map[string]interface{} // step 4) At end of recursion len(Vector) == len(payloads).Cap() - 1 // which translates that Vn = {r0,r1,...,rn} and only rn is missing // in this case/situation iterate over all possible values of rn i.e payload.GetNth(n) + + // Debug: Check if payloads is empty + if payloads.Cap() == 0 { + // No payloads to expand - this will cause pattern to be returned unexpanded + return + } + if len(Vector) == payloads.Cap()-1 { // end of vector vectorMap := map[string]interface{}{} diff --git a/cmd/alterx/main.go b/cmd/alterx/main.go index e87b4c96..5b75b42b 100644 --- a/cmd/alterx/main.go +++ b/cmd/alterx/main.go @@ -14,12 +14,18 @@ func main() { cliOpts := runner.ParseFlags() alterOpts := alterx.Options{ - Domains: cliOpts.Domains, - Patterns: cliOpts.Patterns, - Payloads: cliOpts.Payloads, - Limit: cliOpts.Limit, - Enrich: cliOpts.Enrich, // enrich payloads - MaxSize: cliOpts.MaxSize, + Domains: cliOpts.Domains, + Patterns: cliOpts.Patterns, + Payloads: cliOpts.Payloads, + Limit: cliOpts.Limit, + Enrich: cliOpts.Enrich, // enrich payloads + MaxSize: cliOpts.MaxSize, + Mode: cliOpts.Mode, + MinLDist: cliOpts.MinLDist, + MaxLDist: cliOpts.MaxLDist, + PatternThreshold: cliOpts.PatternThreshold, + PatternQualityRatio: cliOpts.PatternQualityRatio, + NgramsLimit: cliOpts.NgramsLimit, } if cliOpts.PermutationConfig != "" { @@ -44,7 +50,11 @@ func main() { gologger.Fatal().Msgf("failed to open output file %v got %v", cliOpts.Output, err) } output = fs - defer fs.Close() + defer func() { + if err := fs.Close(); err != nil { + gologger.Error().Msgf("failed to close output file: %v", err) + } + }() } else { output = os.Stdout } diff --git a/examples/main.go b/examples/main.go index 1dd52388..fd06a48c 
100644 --- a/examples/main.go +++ b/examples/main.go @@ -20,5 +20,7 @@ func main() { if err != nil { gologger.Fatal().Msg(err.Error()) } - m.ExecuteWithWriter(os.Stdout) + if err := m.ExecuteWithWriter(os.Stdout); err != nil { + gologger.Fatal().Msgf("failed to execute: %v", err) + } } diff --git a/go.mod b/go.mod index de476e51..9857568d 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/projectdiscovery/alterx go 1.23.0 require ( + github.com/armon/go-radix v1.0.0 + github.com/ka-weihe/fast-levenshtein v0.0.0-20201227151214-4c99ee36a1ba github.com/projectdiscovery/fasttemplate v0.0.2 github.com/projectdiscovery/goflags v0.1.72 github.com/projectdiscovery/gologger v1.1.45 diff --git a/go.sum b/go.sum index 395363b5..699d2be2 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0 github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1ow= github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= +github.com/agnivade/levenshtein v1.1.0 h1:n6qGwyHG61v3ABce1rPVZklEYRT8NFpCMrpZdBUbYGM= +github.com/agnivade/levenshtein v1.1.0/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/akrylysov/pogreb v0.10.1 h1:FqlR8VR7uCbJdfUob916tPM+idpKgeESDXOA1K0DK4w= github.com/akrylysov/pogreb v0.10.1/go.mod h1:pNs6QmpQ1UlTJKDezuRWmaqkgUE2TuU0YTWyqJZ7+lI= github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= @@ -15,6 +17,10 @@ github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW5 github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/arbovm/levenshtein 
v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= +github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= +github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI= +github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= @@ -42,6 +48,9 @@ github.com/cnf/structhash v0.0.0-20201127153200-e1b16c1ebc08/go.mod h1:pCxVEbcm3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= +github.com/dgryski/trifles v0.0.0-20200830180326-aaf60a07f6a3 h1:JibukGTEjdN4VMX7YHmXQsLr/gPURUbetlH4E6KvHSU= +github.com/dgryski/trifles v0.0.0-20200830180326-aaf60a07f6a3/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L7HYpRu/0lE3e0BaElwnNO1qkNQxBY= @@ -87,6 +96,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= 
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/ka-weihe/fast-levenshtein v0.0.0-20201227151214-4c99ee36a1ba h1:keZ4vJpYOVm6yrjLzZ6QgozbEBaT0GjfH30ihbO67+4= +github.com/ka-weihe/fast-levenshtein v0.0.0-20201227151214-4c99ee36a1ba/go.mod h1:kaXTPU4xitQT0rfT7/i9O9Gm8acSh3DXr0p4y3vKqiE= github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= @@ -211,8 +222,6 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ulikunitz/xz v0.5.9/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= -github.com/ulikunitz/xz v0.5.14 h1:uv/0Bq533iFdnMHZdRBTOlaNMdb1+ZxXIlHDZHIHcvg= -github.com/ulikunitz/xz v0.5.14/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= @@ -231,6 +240,8 @@ github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfx github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30= go.etcd.io/bbolt v1.3.7 h1:j+zJOnnEjF/kyHlDDgGnVL/AIqIJPq8UoB2GSNfkUfQ= go.etcd.io/bbolt v1.3.7/go.mod h1:N9Mkw9X8x5fupy0IKsmuqVtoGDyxsaDlbk4Rd05IAQw= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod 
h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 51668c92..35f7cf0e 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -28,6 +28,13 @@ type Options struct { Enrich bool Limit int MaxSize int + // Mining/Discovery options + Mode string + MinLDist int + MaxLDist int + PatternThreshold int + PatternQualityRatio int + NgramsLimit int // internal/unexported fields wordlists goflags.RuntimeMap } @@ -60,6 +67,15 @@ func ParseFlags() *Options { flagSet.IntVar(&opts.Limit, "limit", 0, "limit the number of results to return (default 0)"), ) + flagSet.CreateGroup("mining", "Pattern Mining", + flagSet.StringVarP(&opts.Mode, "mode", "m", "", "pattern mode: 'default' (user/default patterns), 'discover' (mined only), 'both' (combined)"), + flagSet.IntVar(&opts.MinLDist, "min-distance", 2, "minimum levenshtein distance for clustering"), + flagSet.IntVar(&opts.MaxLDist, "max-distance", 5, "maximum levenshtein distance for clustering"), + flagSet.IntVar(&opts.PatternThreshold, "pattern-threshold", 500, "pattern threshold for filtering low-quality patterns"), + flagSet.IntVar(&opts.PatternQualityRatio, "quality-ratio", 25, "pattern quality ratio threshold"), + flagSet.IntVar(&opts.NgramsLimit, "ngrams-limit", 0, "limit number of n-grams to process (0 = all)"), + ) + flagSet.CreateGroup("update", "Update", flagSet.CallbackVarP(GetUpdateCallback(), "update", "up", "update alterx to latest version"), flagSet.BoolVarP(&opts.DisableUpdateCheck, "disable-update-check", "duc", false, "disable automatic alterx update check"), diff --git a/mining/clustering.go b/mining/clustering.go new file mode 100644 index 00000000..07428d6d --- /dev/null +++ b/mining/clustering.go @@ -0,0 +1,401 
@@ +package mining + +import ( + "sort" + "strings" +) + +// hierarchicalNgramClustering clusters subdomains using a hierarchical approach that combines +// ngram prefix matching, token extraction, and edit distance clustering. +// +// HIERARCHICAL ALGORITHM: +// For each ngram ('a', 'ab', 'b0', etc.): +// 1. Get keys (hostnames) starting with ngram: keys = trie.keys(ngram) +// 2. Chance 1: Generate pattern from ALL ngram keys directly +// 3. Extract prefixes: Get first token from each hostname +// - first_token("api-prod-1.example.com") → "api" +// 4. For each unique prefix: +// a. Get new keys: keys = trie.keys(prefix) (all hostnames starting with this prefix) +// b. Chance 2: Generate pattern from ALL prefix keys directly +// c. Chance 3 (if prefix length > 1): Do edit distance clustering +// - For each k value, compute closures on prefix keys +// - For each closure, generate pattern +// +// HIERARCHY: ngram → keys → prefixes → new keys per prefix → edit distance clustering → patterns +// +// EXAMPLE: +// Given ngram "a" with keys: ["api-prod-1", "api-prod-2", "app-dev-1"] +// Chance 1: Generate pattern from all 3 keys +// Extract prefixes: ["api", "api", "app"] → unique: ["api", "app"] +// For prefix "api": +// - Get keys starting with "api": ["api-prod-1", "api-prod-2"] +// - Chance 2: Generate pattern from these 2 keys +// - Chance 3: Since len("api") > 1, do edit distance clustering with k=2,3,... +// +// For prefix "app": +// - Get keys starting with "app": ["app-dev-1"] +// - Chance 2: Generate pattern (single item, might be skipped) +// - Chance 3: Since len("app") > 1, do edit distance clustering +func (p *PatternMiner) hierarchicalNgramClustering() error { + // Generate all possible unigrams and bigrams for valid subdomain prefixes + unigrams, bigrams := GenerateValidNgrams() + + // Combine all ngrams for processing + allNgrams := append([]string{}, unigrams...) + allNgrams = append(allNgrams, bigrams...) 
+ + // Apply ngram limit to match Python's behavior (ngrams_limit parameter) + // If NgramsLimit is 0, process all ngrams. Otherwise, limit to first N ngrams. + ngramsToProcess := allNgrams + if p.options.NgramsLimit > 0 && len(allNgrams) > p.options.NgramsLimit { + ngramsToProcess = allNgrams[:p.options.NgramsLimit] + } + + // Process each ngram hierarchically + for _, ngram := range ngramsToProcess { + if err := p.processNgramHierarchy(ngram); err != nil { + return err + } + } + + return nil +} + +// processNgramHierarchy processes a single ngram through the hierarchical clustering pipeline. +// +// ALGORITHM (matches Python reference): +// 1. Get all keys matching the ngram prefix +// 2. Generate pattern from ngram-level keys (Chance 1) +// 3. Extract first tokens to get unique prefixes (sorted) +// 4. Process each prefix level with redundancy filtering +// +// REDUNDANCY FILTERING (matches Python): +// +// Python: if last is None or not prefix.startswith(last) +// Skip prefix if it starts with the previous prefix (redundant) +func (p *PatternMiner) processNgramHierarchy(ngram string) error { + // Step 1: Get all keys (subdomains) starting with this ngram + ngramKeys := p.getSubdomainsByPrefix(ngram) + if len(ngramKeys) == 0 { + return nil + } + + // Step 2: Chance 1 - Generate and store pattern from ALL ngram keys directly + p.tryGenerateAndStorePattern(ngramKeys) + + // Step 3: Extract first tokens (prefixes) from all keys + prefixMap := make(map[string]struct{}) + for _, key := range ngramKeys { + prefix := p.extractFirstToken(key) + if prefix != "" { + prefixMap[prefix] = struct{}{} + } + } + + // Convert to sorted slice for redundancy filtering + // Python: prefixes = sorted(list(set([first_token(k) for k in trie.keys(ngram)]))) + prefixes := make([]string, 0, len(prefixMap)) + for prefix := range prefixMap { + prefixes = append(prefixes, prefix) + } + sort.Strings(prefixes) + + // Step 4: Process each prefix with redundancy filtering + // Python: last = 
None + // if last is None or not prefix.startswith(last): + // if pattern_added: # Only update last when pattern is actually added! + // last = prefix + // else: + // logging.warning(f"Rejecting redundant prefix: {prefix}") + // continue + var last string + for _, prefix := range prefixes { + // Skip if this prefix starts with the previous one (redundant) + if last != "" && strings.HasPrefix(prefix, last) { + // Redundant prefix, skip it + continue + } + + // Process this prefix and check if any pattern was added + patternAdded := p.processPrefixLevel(prefix) + + // Only update last if a pattern was actually added (matches Python behavior) + if patternAdded { + last = prefix + } + } + + return nil +} + +// processPrefixLevel processes clustering at the prefix level. +// +// ALGORITHM (matches Python reference): +// 1. Get all keys matching the prefix +// 2. Generate pattern from prefix-level keys (Chance 2) +// 3. If prefix length > 1, perform edit distance clustering (Chance 3) +// +// PARAMETERS: +// +// prefix - The first token extracted from hostnames (e.g., "api", "web", "app") +// +// RETURNS: +// +// bool - true if at least one pattern was added, false otherwise +func (p *PatternMiner) processPrefixLevel(prefix string) bool { + // Step 1: Get all keys starting with this prefix + prefixKeys := p.getSubdomainsByPrefix(prefix) + if len(prefixKeys) == 0 { + return false + } + + // Step 2: Chance 2 - Generate and store pattern from ALL prefix keys directly + // Python: r = closure_to_regex(args['target'], keys) + // if r not in new_rules and is_good_rule(r, len(keys), ...): + // last = prefix # Only update last when pattern is added + // new_rules.add(r) + patternAdded := p.tryGenerateAndStorePattern(prefixKeys) + + // Step 3: Chance 3 - If prefix length > 1, do edit distance clustering + // Python: if len(prefix) > 1: + if len(prefix) > 1 { + // For each k value (distance threshold), compute closures + for k := p.options.MinLDist; k <= p.options.MaxLDist; k++ { 
+ // Get clusters by levenshtein distance on prefix keys only + clusters, err := p.getLevenshteinClustersForKeys(prefixKeys, k) + if err != nil { + // Log error but continue processing + continue + } + + // For each cluster (closure), generate and store pattern + // Python: for closure in closures: + // r = closure_to_regex(args['target'], closure) + // if r not in new_rules and is_good_rule(r, len(closure), ...): + // new_rules.add(r) + for _, cluster := range clusters { + if p.tryGenerateAndStorePattern(cluster) { + patternAdded = true + } + } + } + } + + return patternAdded +} + +// getSubdomainsByPrefix returns all subdomains that start with the given prefix. +// Uses radix tree for O(k) lookup where k is the number of matching subdomains. +func (p *PatternMiner) getSubdomainsByPrefix(prefix string) []string { + var matches []string + + // WalkPrefix traverses all entries in the tree under the given prefix + p.trie.WalkPrefix(prefix, func(key string, value interface{}) bool { + matches = append(matches, key) + return false // continue walking + }) + + return matches +} + +// levenshteinSubsClustering clusters subdomains by levenshtein distance on subdomain part. +// +// ALGORITHM (matches Python reference): +// For each distance threshold k: +// 1. Get clusters (edit closures) by levenshtein distance +// 2. 
For each cluster, generate and store pattern if it passes quality checks +// +// This matches Python: +// +// for k in range(args['dist_low'], args['dist_high']): +// closures = edit_closures(known_hosts, delta=k) +// for closure in closures: +// if len(closure) > 1: # Already filtered in getClustersByLevenshteinDistance +// r = closure_to_regex(args['target'], closure) +// if r not in new_rules and is_good_rule(r, len(closure), ...): +// new_rules.add(r) +func (p *PatternMiner) levenshteinSubsClustering() error { + // Get clusters by levenshtein distance starting from min to max + for k := p.options.MinLDist; k <= p.options.MaxLDist; k++ { + clusters, err := p.getClustersByLevenshteinDistance(k) + if err != nil { + return err + } + + // For each cluster, generate and store pattern + for _, cluster := range clusters { + p.tryGenerateAndStorePattern(cluster) + } + } + return nil +} + +// getClustersByLevenshteinDistance computes clusters of subdomains bounded by edit distance. +// +// ALGORITHM: +// For each subdomain 'a', create a cluster containing: +// - The subdomain 'a' itself +// - All subdomains 'b' where distance(a, b) < k +// +// Then deduplicate identical clusters and discard singletons. +// +// EXAMPLE with k=2: +// +// Given subdomains: api, api1, api12 +// Distances: api↔api1=1, api1↔api12=1, api↔api12=2 +// +// Step 1: Build cluster for each subdomain +// +// Cluster from 'api': {api, api1} (api1 dist=1 < 2, api12 dist=2 NOT < 2) +// Cluster from 'api1': {api1, api, api12} (api dist=1 < 2, api12 dist=1 < 2) +// Cluster from 'api12': {api12, api1} (api1 dist=1 < 2, api dist=2 NOT < 2) +// +// Step 2: Deduplicate (no identical clusters in this case) +// +// Result: [{api, api1}, {api1, api, api12}, {api12, api1}] +// +// Step 3: Filter singletons (none in this case) +// +// Final: [{api, api1}, {api1, api, api12}, {api12, api1}] +// +// IMPORTANT PROPERTY: +// Items in a cluster don't need to be close to EACH OTHER, only to the CENTER item. 
+// In the example above, {api1, api, api12} is a valid cluster even though api↔api12=2 (not < k), +// because both api and api12 are within distance < 2 from the center item api1. +// +// PARAMETERS: +// +// k - Distance threshold (strictly less than, not <=) +// +// RETURNS: +// +// Clusters with 2+ items (singletons are discarded) +func (p *PatternMiner) getClustersByLevenshteinDistance(k int) ([][]string, error) { + if len(p.subdomains) == 0 { + return nil, nil + } + + type cluster map[string]struct{} + var result []cluster + + // For each item 'a', create a cluster containing all items within distance < k from 'a' + for _, a := range p.subdomains { + currentCluster := make(cluster) + currentCluster[a] = struct{}{} // Always include the center item itself + + // Find all items 'b' within distance < k from center item 'a' + for _, b := range p.subdomains { + if a == b { + continue // Already added above + } + + edge := NewEdge(a, b) + if dist, ok := p.distanceMap[edge]; ok && dist < k { + currentCluster[b] = struct{}{} + } + } + + // Deduplicate: Check if this exact cluster already exists in results + found := false + for _, existingCluster := range result { + if clustersEqual_internal(currentCluster, existingCluster) { + found = true + break + } + } + + if !found { + result = append(result, currentCluster) + } + } + + // Convert to slice format and filter out singleton clusters + finalResult := make([][]string, 0, len(result)) + for _, c := range result { + if len(c) > 1 { + items := make([]string, 0, len(c)) + for item := range c { + items = append(items, item) + } + finalResult = append(finalResult, items) + } + } + + return finalResult, nil +} + +// getLevenshteinClustersForKeys computes levenshtein distance clusters for a specific subset of keys. +// This is similar to getClustersByLevenshteinDistance but operates on a provided subset of keys +// instead of all subdomains. 
+// +// ALGORITHM: +// For each key 'a' in the provided keys: +// +// Create a cluster containing: +// - The key 'a' itself +// - All keys 'b' from the same subset where distance(a, b) < k +// +// PARAMETERS: +// +// keys - Subset of subdomains to cluster +// k - Distance threshold (strictly less than, not <=) +// +// RETURNS: +// +// Clusters with 2+ items (singletons are discarded) +// +// TODO: Implement levenshtein clustering on subset of keys +func (p *PatternMiner) getLevenshteinClustersForKeys(keys []string, k int) ([][]string, error) { + if len(keys) == 0 { + return nil, nil + } + + type cluster map[string]struct{} + var result []cluster + + // For each item 'a' in keys, create a cluster containing all items within distance < k from 'a' + for _, a := range keys { + currentCluster := make(cluster) + currentCluster[a] = struct{}{} // Always include the center item itself + + // Find all items 'b' within distance < k from center item 'a' + for _, b := range keys { + if a == b { + continue // Already added above + } + + // Look up distance from pre-computed distance map + edge := NewEdge(a, b) + if dist, ok := p.distanceMap[edge]; ok && dist < k { + currentCluster[b] = struct{}{} + } + } + + // Deduplicate: Check if this exact cluster already exists in results + found := false + for _, existingCluster := range result { + if clustersEqual_internal(currentCluster, existingCluster) { + found = true + break + } + } + + if !found { + result = append(result, currentCluster) + } + } + + // Convert to slice format and filter out singleton clusters + finalResult := make([][]string, 0, len(result)) + for _, c := range result { + if len(c) > 1 { + items := make([]string, 0, len(c)) + for item := range c { + items = append(items, item) + } + finalResult = append(finalResult, items) + } + } + + return finalResult, nil +} diff --git a/mining/hierarchical_clustering_test.go b/mining/hierarchical_clustering_test.go new file mode 100644 index 00000000..33c30750 --- /dev/null 
+++ b/mining/hierarchical_clustering_test.go @@ -0,0 +1,257 @@ +package mining + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestPatternDeduplication verifies that duplicate patterns are not stored twice +func TestPatternDeduplication(t *testing.T) { + domains := []string{ + "api-prod.example.com", + "api-staging.example.com", + } + + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 10, + PatternThreshold: 1000, + PatternQualityRatio: 100, + }) + require.NoError(t, err) + + // Generate the same pattern twice + success1 := pm.tryGenerateAndStorePattern([]string{"api-prod", "api-staging"}) + success2 := pm.tryGenerateAndStorePattern([]string{"api-prod", "api-staging"}) + + assert.True(t, success1, "First pattern should be stored") + assert.False(t, success2, "Second identical pattern should be rejected (duplicate)") + + results := pm.GetResults() + assert.Len(t, results, 1, "Should only have one pattern (deduplication working)") +} + +// TestRedundantPrefixFiltering verifies that redundant prefixes are skipped +func TestRedundantPrefixFiltering(t *testing.T) { + // Test data where we have redundant prefixes + // "api" and "api-prod" where "api-prod" starts with "api" + domains := []string{ + "api.example.com", + "api-prod.example.com", + "api-prod-1.example.com", + "api-staging.example.com", + } + + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 3, + PatternThreshold: 1000, + PatternQualityRatio: 100, + }) + require.NoError(t, err) + + // Manually test processNgramHierarchy with ngram "a" + err = pm.processNgramHierarchy("a") + require.NoError(t, err) + + results := pm.GetResults() + + // Verify patterns were generated + assert.Greater(t, len(results), 0, "Should generate at least one pattern") + + // Check that we're not generating redundant patterns + // This is more of a smoke test - the real test is that it doesn't error + t.Logf("Generated %d 
patterns", len(results)) + for i, pattern := range results { + t.Logf("Pattern %d: %s with %d payloads", i+1, pattern.Pattern, len(pattern.Payloads)) + } +} + +// TestHierarchicalClustering verifies the full hierarchical clustering pipeline +func TestHierarchicalClustering(t *testing.T) { + domains := []string{ + "api-prod-1.example.com", + "api-prod-2.example.com", + "api-staging-1.example.com", + "web-prod.example.com", + "web-staging.example.com", + } + + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 5, + PatternThreshold: 1000, + PatternQualityRatio: 100, + }) + require.NoError(t, err) + + // Run hierarchical clustering + err = pm.hierarchicalNgramClustering() + require.NoError(t, err) + + results := pm.GetResults() + + // Should generate multiple patterns at different levels + assert.Greater(t, len(results), 0, "Should generate at least one pattern") + + t.Logf("Generated %d total patterns:", len(results)) + for i, pattern := range results { + t.Logf(" %d. 
%s (payloads: %d)", i+1, pattern.Pattern, len(pattern.Payloads)) + } +} + +// TestFullExecutePipeline tests the complete Execute() workflow +func TestFullExecutePipeline(t *testing.T) { + domains := []string{ + "api-prod.example.com", + "api-staging.example.com", + "api-dev.example.com", + "web-prod.example.com", + "web-staging.example.com", + } + + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 5, + PatternThreshold: 1000, + PatternQualityRatio: 100, + }) + require.NoError(t, err) + + // Execute full pipeline + err = pm.Execute() + require.NoError(t, err) + + results := pm.GetResults() + + // Should generate patterns + assert.Greater(t, len(results), 0, "Execute should generate patterns") + + // Verify all patterns are unique (deduplication working) + seenPatterns := make(map[string]bool) + for _, pattern := range results { + assert.False(t, seenPatterns[pattern.Pattern], "Pattern %s appears twice (deduplication failed)", pattern.Pattern) + seenPatterns[pattern.Pattern] = true + } + + t.Logf("Execute generated %d unique patterns:", len(results)) + for i, pattern := range results { + combinations := 1 + for _, payload := range pattern.Payloads { + combinations *= len(payload) + } + t.Logf(" %d. 
%s → %d combinations", i+1, pattern.Pattern, combinations) + } +} + +// TestQualityFilteringDuringClustering verifies bad patterns are rejected during clustering +func TestQualityFilteringDuringClustering(t *testing.T) { + // Create patterns that would be too generic + domains := []string{ + "a.example.com", + "b.example.com", + "c.example.com", + "d.example.com", + "e.example.com", + } + + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 5, + PatternThreshold: 2, // Very strict - reject patterns with >2 combinations + PatternQualityRatio: 0.5, // Strict ratio + }) + require.NoError(t, err) + + // Try to generate a pattern from all 5 - should be rejected + // Pattern {{p0}} with 5 values = 5 combinations, 5 inputs + // Ratio = 5/5 = 1.0 > 0.5, and 5 > 2, so rejected + success := pm.tryGenerateAndStorePattern([]string{"a", "b", "c", "d", "e"}) + assert.False(t, success, "Generic pattern should be rejected") + + // Try with just 2 - Pattern {{p0}} with 2 values = 2 combinations, 2 inputs + // Ratio = 2/2 = 1.0 which is NOT < 0.5, and 2 is NOT < 2 + // So this will also be rejected! 
Let me use 3 threshold instead + pm.options.PatternThreshold = 3 // Now 2 < 3 will pass + success = pm.tryGenerateAndStorePattern([]string{"a", "b"}) + assert.True(t, success, "Simple pattern should pass with threshold=3") + + results := pm.GetResults() + assert.Len(t, results, 1, "Should only have the accepted pattern") +} + +// TestMaxLengthFiltering verifies patterns exceeding max length are rejected +func TestMaxLengthFiltering(t *testing.T) { + domains := []string{ + "very-long-subdomain-name-here-prod.example.com", + "very-long-subdomain-name-here-staging.example.com", + } + + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 10, + PatternThreshold: 1000, + PatternQualityRatio: 100, + MaxPatternLength: 20, // Pattern will be longer than this + }) + require.NoError(t, err) + + success := pm.tryGenerateAndStorePattern([]string{ + "very-long-subdomain-name-here-prod", + "very-long-subdomain-name-here-staging", + }) + + assert.False(t, success, "Long pattern should be rejected") + assert.Len(t, pm.GetResults(), 0, "No patterns should be stored") +} + +// TestStorePattern verifies the storePattern deduplication logic +func TestStorePattern(t *testing.T) { + pm := &PatternMiner{ + results: make([]*DSLPattern, 0), + seenPatterns: make(map[string]struct{}), + } + + pattern1 := &DSLPattern{ + Pattern: "api{{p0}}", + Payloads: map[string][]string{ + "p0": {"-prod", "-staging"}, + }, + } + + pattern2 := &DSLPattern{ + Pattern: "api{{p0}}", // Same pattern string + Payloads: map[string][]string{ + "p0": {"-dev", "-test"}, // Different payloads but same pattern string + }, + } + + pattern3 := &DSLPattern{ + Pattern: "web{{p0}}", // Different pattern + Payloads: map[string][]string{ + "p0": {"-prod", "-staging"}, + }, + } + + // First pattern should be stored + stored := pm.storePattern(pattern1) + assert.True(t, stored, "First pattern should be stored") + assert.Len(t, pm.results, 1) + + // Second pattern with same pattern string should be rejected 
+ stored = pm.storePattern(pattern2) + assert.False(t, stored, "Duplicate pattern string should be rejected") + assert.Len(t, pm.results, 1, "Should still have only 1 pattern") + + // Third pattern with different string should be stored + stored = pm.storePattern(pattern3) + assert.True(t, stored, "Different pattern should be stored") + assert.Len(t, pm.results, 2) + + // Nil pattern should not be stored + stored = pm.storePattern(nil) + assert.False(t, stored, "Nil pattern should not be stored") + assert.Len(t, pm.results, 2) +} diff --git a/mining/pattern_diff_test.go b/mining/pattern_diff_test.go new file mode 100644 index 00000000..b5760693 --- /dev/null +++ b/mining/pattern_diff_test.go @@ -0,0 +1,191 @@ +package mining + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestPatternDifferences identifies patterns unique to Python vs unique to Go +func TestPatternDifferences(t *testing.T) { + testCases := []string{"nuclei.sh", "projectdiscovery.io", "tesla.com"} + + for _, dataFile := range testCases { + t.Run(dataFile, func(t *testing.T) { + // Load test data + domains := loadTestData(t, dataFile) + if len(domains) == 0 { + t.Skip("No test data available") + } + + analyzePatternDifferences(t, domains) + }) + } +} + +func analyzePatternDifferences(t *testing.T, domains []string) { + + // Run Python pattern generation + pythonPatterns := runPythonHierarchicalPatterns(t, domains, 2, 5, 1000, 100, 100) + + // Run Go pattern generation (with same ngram limit as Python test) + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 5, + PatternThreshold: 1000, + PatternQualityRatio: 100, + NgramsLimit: 100, // Match Python's ngrams_limit parameter + }) + require.NoError(t, err) + + err = pm.hierarchicalNgramClustering() + require.NoError(t, err) + + goPatterns := pm.GetResults() + + // Convert to sets + pythonSet := make(map[string]struct{}) + for _, p := range pythonPatterns.Patterns { + pythonSet[p] = struct{}{} + } + + 
goSet := make(map[string]struct{}) + for _, p := range goPatterns { + goSet[p.Pattern] = struct{}{} + } + + // Find patterns unique to Python + pythonOnly := []string{} + for pattern := range pythonSet { + if _, exists := goSet[pattern]; !exists { + pythonOnly = append(pythonOnly, pattern) + } + } + + // Find patterns unique to Go + goOnly := []string{} + for pattern := range goSet { + if _, exists := pythonSet[pattern]; !exists { + goOnly = append(goOnly, pattern) + } + } + + // Find common patterns + common := []string{} + for pattern := range pythonSet { + if _, exists := goSet[pattern]; exists { + common = append(common, pattern) + } + } + + // Report findings + t.Logf("Python total: %d patterns", len(pythonSet)) + t.Logf("Go total: %d patterns", len(goSet)) + t.Logf("Common patterns: %d", len(common)) + t.Logf("Unique to Python: %d", len(pythonOnly)) + t.Logf("Unique to Go: %d", len(goOnly)) + + if len(pythonOnly) > 0 { + t.Logf("\nPatterns ONLY in Python (first 20):") + for i, p := range pythonOnly { + if i >= 20 { + t.Logf(" ... and %d more", len(pythonOnly)-20) + break + } + t.Logf(" - %s", p) + } + } + + if len(goOnly) > 0 { + maxShow := 30 + if len(goOnly) < maxShow { + maxShow = len(goOnly) + } + t.Logf("\nPatterns ONLY in Go (first %d with examples):", maxShow) + count := 0 + for _, goPattern := range goPatterns { + if _, exists := pythonSet[goPattern.Pattern]; !exists { + // This pattern is unique to Go + t.Logf("\n Pattern: %s", goPattern.Pattern) + + // Show payloads + if len(goPattern.Payloads) > 0 { + for key, values := range goPattern.Payloads { + if len(values) <= 5 { + t.Logf(" %s: %v", key, values) + } else { + t.Logf(" %s: [%s, %s, %s, ... 
and %d more]",
+							key, values[0], values[1], values[2], len(values)-3)
+					}
+				}
+			}
+
+			// Generate and show example combinations
+			examples := generateExamples(goPattern, 3)
+			if len(examples) > 0 {
+				t.Logf(" Examples: %v", examples)
+			}
+
+			count++
+			if count >= maxShow {
+				if len(goOnly) > maxShow {
+					t.Logf("\n ... and %d more extra patterns", len(goOnly)-maxShow)
+				}
+				break
+			}
+		}
+	}
+}
+
+// generateExamples renders up to maxExamples concrete example strings from a
+// DSL pattern by substituting payload values into its {{name}} placeholders.
+//
+// Behavior:
+//   - nil pattern          -> nil (FIX: previously dereferenced nil and panicked)
+//   - no payloads (static) -> the pattern string itself
+//   - exactly one payload  -> one example per payload value, capped at maxExamples
+//   - multiple payloads    -> a single example built from each payload's first value
+func generateExamples(pattern *DSLPattern, maxExamples int) []string {
+	// BUG FIX: guard nil before touching pattern.Pattern. The original combined
+	// the nil check and the empty-payloads check in one condition and then
+	// unconditionally read pattern.Pattern, panicking on nil input.
+	if pattern == nil {
+		return nil
+	}
+	if len(pattern.Payloads) == 0 {
+		return []string{pattern.Pattern}
+	}
+
+	// Simple case: single payload — emit one example per value (up to maxExamples).
+	if len(pattern.Payloads) == 1 {
+		var key string
+		var values []string
+		for k, v := range pattern.Payloads {
+			key = k
+			values = v
+			break
+		}
+
+		examples := []string{}
+		for i := 0; i < len(values) && i < maxExamples; i++ {
+			example := pattern.Pattern
+			placeholder := "{{" + key + "}}"
+			example = replaceFirst(example, placeholder, values[i])
+			examples = append(examples, example)
+		}
+		return examples
+	}
+
+	// Multiple payloads - just show first combination
+	example := pattern.Pattern
+	for key, values := range pattern.Payloads {
+		if len(values) > 0 {
+			placeholder := "{{" + key + "}}"
+			example = replaceFirst(example, placeholder, values[0])
+		}
+	}
+	return []string{example}
+}
+
+// replaceFirst returns s with the first occurrence of old replaced by repl.
+// If old does not occur in s, s is returned unchanged.
+func replaceFirst(s, old, repl string) string {
+	// BUG FIX: the original initialized idx to 0, so when old was NOT found the
+	// final bounds check (idx >= 0 && idx+len(old) <= len(s)) still passed and
+	// the start of s was silently overwritten with the replacement. Starting at
+	// -1 makes "not found" unambiguous. (Equivalent to strings.Replace(s, old,
+	// repl, 1), kept hand-rolled to avoid adding an import to this file.)
+	idx := -1
+	for i := 0; i <= len(s)-len(old); i++ {
+		if s[i:i+len(old)] == old {
+			idx = i
+			break
+		}
+	}
+	if idx < 0 {
+		return s
+	}
+	return s[:idx] + repl + s[idx+len(old):]
+}
diff --git a/mining/pattern_generation.go b/mining/pattern_generation.go new file mode 100644 index 00000000..e502a59d --- /dev/null +++ b/mining/pattern_generation.go @@ -0,0 +1,523 @@ +package mining
+
+import (
+	"fmt"
+	"sort"
+	"strings"
+)
+
+// Mined pattern representation in DSL format
+type 
DSLPattern struct { + ID string `json:"id"` + Metadata map[string]interface{} `json:"metadata"` + Pattern string `json:"pattern"` + Payloads map[string][]string `json:"payloads"` +} + +// generatePattern generates a DSL pattern from a set of subdomain strings using tokenization. +// +// ALGORITHM (matches Python reference implementation): +// 1. Tokenize subdomains into hierarchical tokens +// 2. Analyze token alignment to classify positions as static/variable +// 3. Build DSL pattern string from alignment analysis +// 4. Extract payload values for variable positions +// 5. Apply quality checks +// +// EXAMPLES: +// +// Input: ["api-prod", "api-staging"] +// Tokenized: [["api", "-prod"], ["api", "-staging"]] +// Analysis: Position 0 = static("api"), Position 1 = variable(["-prod", "-staging"]) +// Output: DSLPattern{ +// Pattern: "api{{p0}}", +// Payloads: {"p0": ["-prod", "-staging"]} +// } +// +// Input: ["api-prod-1", "api-prod-2", "api-staging-1"] +// Tokenized: [["api", "-prod", "-1"], ["api", "-prod", "-2"], ["api", "-staging", "-1"]] +// Analysis: pos0=static("api"), pos1=variable(["-prod","-staging"]), pos2=variable(["-1","-2"]) +// Output: DSLPattern{ +// Pattern: "api{{p0}}{{p1}}", +// Payloads: {"p0": ["-prod", "-staging"], "p1": ["-1", "-2"]} +// } +// +// RETURNS: +// - *DSLPattern with pattern string and payloads map +// - error if generation fails +func (p *PatternMiner) generatePattern(subdomains []string) (*DSLPattern, error) { + if len(subdomains) == 0 { + return nil, nil + } + + // Single subdomain - return as-is (static pattern) + if len(subdomains) == 1 { + return &DSLPattern{ + Pattern: subdomains[0], + Payloads: make(map[string][]string), + Metadata: make(map[string]interface{}), + }, nil + } + + // Sort for consistency (matches Python) + sorted := make([]string, len(subdomains)) + copy(sorted, subdomains) + sort.Strings(sorted) + + // Check if all subdomains are identical + allSame := true + for _, s := range sorted[1:] { + if s != sorted[0] 
{ + allSame = false + break + } + } + + // If all subdomains are identical, return as static pattern (no variables) + if allSame { + return &DSLPattern{ + Pattern: sorted[0], + Payloads: make(map[string][]string), + Metadata: make(map[string]interface{}), + }, nil + } + + // STEP 1: Tokenize subdomains into hierarchical structure + tokenized := Tokenize(sorted) + + // STEP 2: Analyze token alignment across all subdomains + // This classifies each position as static (same value) or variable (different values) + levelPositions := p.analyzeTokenAlignment(tokenized) + + // STEP 3: Build DSL pattern string from alignment analysis + // Static positions become literals, variable positions become {{p0}}, {{p1}}, etc. + pattern := p.buildDSLPattern(levelPositions) + + // STEP 4: Extract payload values for each variable position + // Creates map of variable_name → []possible_values + payloads := p.extractPayloads(levelPositions, tokenized) + + // Create DSL pattern structure + dslPattern := &DSLPattern{ + Pattern: pattern, + Payloads: payloads, + Metadata: make(map[string]interface{}), + } + + // STEP 5: Quality checks + + // Max length check + if p.options.MaxPatternLength > 0 && len(pattern) > p.options.MaxPatternLength { + return nil, nil + } + + // Pattern quality check (threshold and ratio) + if !p.isGoodPattern(dslPattern, len(subdomains)) { + return nil, nil + } + + return dslPattern, nil +} + +// analyzeTokenAlignment analyzes token positions across multiple tokenized subdomains. +// +// ALGORITHM: +// 1. Build hierarchical map: levels[levelIdx][positionIdx] = set of unique tokens +// 2. For each level and position, collect all token values across all subdomains +// 3. Classify each position: +// - STATIC: len(unique_values) == 1 (all subdomains have same token) +// - VARIABLE: len(unique_values) > 1 (different tokens exist) +// 4. 
Detect OPTIONAL positions and levels: +// - Position optional: not all subdomains (with that level) have token at that position +// - Level optional: not all subdomains have that level +// +// SPECIAL CASE EXAMPLES: +// +// Example 1 - Optional Position: +// +// Input: ["api-prod", "api"] +// Level 0, Position 1: only "api-prod" has "-prod" → Position 1 is OPTIONAL +// Result: pattern "api{{p0}}", payloads: {"p0": ["-prod", ""]} +// +// Example 2 - Optional Level: +// +// Input: ["api.dev", "api"] +// Level 1: only "api.dev" has "dev" → Level 1 is OPTIONAL +// Result: pattern "api.{{p0}}", payloads: {"p0": ["dev", ""]} +// +// Example 3 - Variable Position: +// +// Input: ["api-prod-1", "api-staging-2"] +// Level 0, Position 1: has {"-prod", "-staging"} → VARIABLE +// Level 0, Position 2: has {"-1", "-2"} → VARIABLE +// Result: pattern "api{{p0}}{{p1}}" +// +// RETURNS: []LevelPosition with classification and optionality metadata +func (p *PatternMiner) analyzeTokenAlignment(tokenized []TokenizedSubdomain) []LevelPosition { + if len(tokenized) == 0 { + return nil + } + + // Build hierarchical map: levels[levelIdx][positionIdx] = set of unique values + // This allows us to see all variations at each position to detect STATIC vs VARIABLE + levels := make(map[int]map[int]map[string]struct{}) + + // Track all occurrences (including duplicates) to detect OPTIONAL positions + // If len(optional[level][pos]) < totalMembers, position is optional + optional := make(map[int]map[int][]string) + + // STEP 1: Collect all tokens at each level and position across all subdomains + for _, ts := range tokenized { + for levelIdx, level := range ts.Levels { + if _, ok := levels[levelIdx]; !ok { + levels[levelIdx] = make(map[int]map[string]struct{}) + optional[levelIdx] = make(map[int][]string) + } + + for posIdx, token := range level.Tokens { + if _, ok := levels[levelIdx][posIdx]; !ok { + levels[levelIdx][posIdx] = make(map[string]struct{}) + optional[levelIdx][posIdx] = []string{} + 
} + levels[levelIdx][posIdx][token] = struct{}{} // unique values + optional[levelIdx][posIdx] = append(optional[levelIdx][posIdx], token) // all occurrences + } + } + } + + // STEP 2: Build LevelPosition structures with classification + result := make([]LevelPosition, 0) + totalMembers := len(tokenized) + varCounter := 0 // Sequential counter for variable names: p0, p1, p2, ... + + for levelIdx := 0; levelIdx < len(levels); levelIdx++ { + levelData := levels[levelIdx] + lp := LevelPosition{ + LevelIndex: levelIdx, + Positions: make([]TokenPosition, 0), + } + + // SPECIAL CASE: Detect if entire level is optional + // Level is optional when some subdomains lack this level entirely + // Example: ["api", "api.dev"] → Level 1 is optional (only second has it) + membersWithLevel := 0 + for _, ts := range tokenized { + if levelIdx < len(ts.Levels) { + membersWithLevel++ + } + } + lp.IsOptional = membersWithLevel < totalMembers + + // Analyze each position in this level + for posIdx := 0; posIdx < len(levelData); posIdx++ { + if tokens, ok := levelData[posIdx]; ok { + tp := TokenPosition{ + Index: posIdx, + Values: make([]string, 0, len(tokens)), + } + + // Collect unique values at this position + for token := range tokens { + tp.Values = append(tp.Values, token) + } + + // STEP 1: Detect if position is optional + // Position is optional when NOT ALL subdomains (that have this level) + // have a token at this position + // Example: ["api-prod", "api"] → Position 1 ("-prod") is optional + positionCount := len(optional[levelIdx][posIdx]) + tp.IsOptional = positionCount < totalMembers + + // STEP 2: Classify position as Static or Variable + // A position is VARIABLE if: + // 1. Multiple unique values exist (len > 1), OR + // 2. 
Position is optional (needs empty string in payload) + // + // CRITICAL: Optional positions with single value must be VARIABLE + // Example: ["api-prod", "api"] → position 1 has ["-prod"] but is optional + // Must be VARIABLE to allow both "api-prod" and "api" generation + if len(tp.Values) > 1 || tp.IsOptional { + tp.Type = TokenPositionVariable + // Variable positions get placeholder name: p0, p1, p2, ... + tp.VarName = fmt.Sprintf("p%d", varCounter) + varCounter++ + } else { + tp.Type = TokenPositionStatic + // Static positions use literal value in pattern + } + + lp.Positions = append(lp.Positions, tp) + } + } + + result = append(result, lp) + } + + return result +} + +// TokenPosition represents metadata about a token position in the pattern. +type TokenPosition struct { + Index int // Position index in token array + Type TokenPositionType // Static or Variable + Values []string // All values seen at this position + VarName string // Variable name if Type is Variable (e.g., "p0", "p1") + IsOptional bool // Whether this position is optional (not all members have it) +} + +// LevelPosition represents all token positions within a single hierarchical level. +type LevelPosition struct { + LevelIndex int // Index of this level in the hierarchy (0 = leftmost subdomain part) + Positions []TokenPosition // Token positions within this level + IsOptional bool // Whether this entire level is optional +} + +// TokenPositionType indicates whether a token position is static or variable. +type TokenPositionType int + +const ( + // TokenPositionStatic indicates all subdomains have same value at this position + TokenPositionStatic TokenPositionType = iota + // TokenPositionVariable indicates subdomains have different values at this position + TokenPositionVariable +) + +// buildDSLPattern constructs a DSL pattern string from level position analysis. +// +// ALGORITHM: +// 1. Iterate through levels and their token positions +// 2. For static positions: use literal value +// 3. 
For variable positions: use placeholder syntax (e.g., {{p0}}) +// 4. Join levels with dots, positions within levels directly +// +// EXAMPLE: +// +// Input: Level 0: [Static("api"), Variable("p0"), Variable("p1")] +// Output: "api{{p0}}{{p1}}" +func (p *PatternMiner) buildDSLPattern(levelPositions []LevelPosition) string { + if len(levelPositions) == 0 { + return "" + } + + levels := make([]string, 0, len(levelPositions)) + + for _, lp := range levelPositions { + levelPattern := "" + + for _, tp := range lp.Positions { + if tp.Type == TokenPositionStatic { + // Use literal value for static positions + if len(tp.Values) > 0 { + levelPattern += tp.Values[0] + } + } else { + // Use placeholder for variable positions + levelPattern += "{{" + tp.VarName + "}}" + } + } + + levels = append(levels, levelPattern) + } + + // Join levels with dots + return strings.Join(levels, ".") +} + +// extractPayloads extracts payload values for each variable in the pattern. +// +// ALGORITHM: +// 1. For each variable position in the pattern +// 2. Collect all unique values seen at that position +// 3. IMPORTANT: Add empty string "" if position is optional +// (this allows the variable to be omitted when generating domains) +// 4. 
Build a map of variable_name → []values +// +// EXAMPLES: +// +// Example 1 - Required positions: +// +// Pattern: "api{{p0}}{{p1}}" +// Subdomains: ["api-prod-1", "api-prod-2", "api-staging-1"] +// Output: {"p0": ["-prod", "-staging"], "p1": ["-1", "-2"]} +// +// Example 2 - Optional position: +// +// Pattern: "api{{p0}}" +// Subdomains: ["api-prod", "api"] (second one lacks "-prod") +// Output: {"p0": ["-prod", ""]} ← Note: "" allows generation of "api" +// +// Example 3 - Optional level: +// +// Pattern: "api.{{p0}}" +// Subdomains: ["api.dev", "api"] (second one lacks .dev level) +// Output: {"p0": ["dev", ""]} ← Note: "" allows generation of "api" +func (p *PatternMiner) extractPayloads(levelPositions []LevelPosition, tokenized []TokenizedSubdomain) map[string][]string { + payloads := make(map[string][]string) + + for _, lp := range levelPositions { + for _, tp := range lp.Positions { + // Only extract payloads for VARIABLE positions + // Static positions don't need payloads (they use literal values) + if tp.Type == TokenPositionVariable { + // Collect unique values + uniqueValues := make(map[string]struct{}) + + for _, val := range tp.Values { + uniqueValues[val] = struct{}{} + } + + // SPECIAL CASE: Add empty string for optional positions + // This allows pattern generator to omit the variable + // Example: {"p0": ["-prod", ""]} allows both "api-prod" and "api" + if tp.IsOptional { + uniqueValues[""] = struct{}{} + } + + // Convert to slice + values := make([]string, 0, len(uniqueValues)) + for val := range uniqueValues { + values = append(values, val) + } + + payloads[tp.VarName] = values + } + } + } + + return payloads +} + +// tryGenerateAndStorePattern attempts to generate a pattern from subdomains and store it. +// +// ALGORITHM (matches Python workflow): +// 1. Generate pattern from subdomain closure +// 2. If pattern passes quality checks, store it (with deduplication) +// 3. 
Return true if pattern was generated and stored +// +// This implements the Python pattern generation and storage flow: +// +// r = closure_to_regex(args['target'], closure) +// if r not in new_rules and is_good_rule(r, len(closure), ...): +// new_rules.add(r) +func (p *PatternMiner) tryGenerateAndStorePattern(subdomains []string) bool { + // Generate pattern (includes quality checks) + pattern, err := p.generatePattern(subdomains) + if err != nil || pattern == nil { + return false // Pattern generation failed or was rejected + } + + // Store pattern (with deduplication) + return p.storePattern(pattern) +} + +// storePattern stores a validated pattern in the results collection with deduplication. +// +// ALGORITHM (matches Python: if r not in new_rules): +// 1. Check if pattern already seen (deduplication) +// 2. If new, add to results and mark as seen +// 3. Return true if stored, false if duplicate +// +// This implements the Python pattern storage logic: +// +// if r not in new_rules: +// new_rules.add(r) +func (p *PatternMiner) storePattern(pattern *DSLPattern) bool { + if pattern == nil { + return false + } + + // Check if we've already generated this pattern (deduplication) + if _, exists := p.seenPatterns[pattern.Pattern]; exists { + return false // Duplicate pattern, skip + } + + // Mark pattern as seen + p.seenPatterns[pattern.Pattern] = struct{}{} + + // Add to results collection + p.results = append(p.results, pattern) + + return true +} + +// isGoodPattern applies quality checks to determine if a pattern is acceptable. +// +// ALGORITHM (matches Python is_good_rule): +// 1. Calculate total combinations: product of all payload lengths (clusterbomb style) +// 2. Apply two checks: +// - Absolute check: nwords < threshold (reject if generates too many) +// - Ratio check: (nwords/nkeys) < max_ratio (reject if expansion ratio too high) +// 3. 
Pattern is good if: nwords < threshold OR ratio < max_ratio +// +// PARAMETERS: +// - pattern: The DSL pattern to evaluate +// - nkeys: Number of input subdomains used to generate this pattern +// +// EXAMPLE: +// +// Pattern: "api{{p0}}.{{p1}}" with payloads {p0: ["-prod", "-staging"], p1: ["dev", "staging"]} +// nwords = 2 × 2 = 4 combinations +// nkeys = 2 (original subdomains) +// ratio = 4/2 = 2.0 +// +// If threshold=100 and max_ratio=10: +// - 4 < 100 ✓ (passes absolute check) +// - 2.0 < 10 ✓ (passes ratio check) +// → Pattern is GOOD +// +// RETURNS: +// - true if pattern meets quality criteria +// - false if pattern is too generic (should be discarded) +func (p *PatternMiner) isGoodPattern(pattern *DSLPattern, nkeys int) bool { + // Calculate total number of combinations (clusterbomb style) + nwords := p.calculateCombinations(pattern) + + threshold := int(p.options.PatternThreshold) + maxRatio := p.options.PatternQualityRatio + + // Pattern is good if it's below threshold OR has acceptable ratio + // This matches Python: return nwords < threshold or (nwords/nkeys) < max_ratio + if threshold > 0 && nwords < threshold { + return true + } + + if maxRatio > 0 && nkeys > 0 { + ratio := float64(nwords) / float64(nkeys) + return ratio < maxRatio + } + + // If no thresholds configured, accept all patterns + return true +} + +// calculateCombinations calculates total number of output combinations for a DSL pattern. 
+// +// ALGORITHM: +// +// Total combinations = product of all payload lengths (clusterbomb multiplication) +// +// EXAMPLE: +// +// Pattern: "api{{p0}}.{{p1}}" +// Payloads: {p0: ["-prod", "-staging", "-dev"], p1: ["us", "eu"]} +// Total = 3 × 2 = 6 combinations: +// - api-prod.us +// - api-prod.eu +// - api-staging.us +// - api-staging.eu +// - api-dev.us +// - api-dev.eu +// +// RETURNS: +// - Total number of possible combinations +func (p *PatternMiner) calculateCombinations(pattern *DSLPattern) int { + if len(pattern.Payloads) == 0 { + return 1 // Static pattern generates 1 output + } + + total := 1 + for _, values := range pattern.Payloads { + total *= len(values) + } + + return total +} diff --git a/mining/pattern_generation_test.go b/mining/pattern_generation_test.go new file mode 100644 index 00000000..49f64039 --- /dev/null +++ b/mining/pattern_generation_test.go @@ -0,0 +1,582 @@ +package mining + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGeneratePattern(t *testing.T) { + tests := []struct { + name string + subdomains []string + expectedPattern string + expectedPayloadLen int + checkPayloads map[string][]string + }{ + { + name: "simple static pattern", + subdomains: []string{"api", "api", "api"}, + expectedPattern: "api", + expectedPayloadLen: 0, + }, + { + name: "single level with variable", + subdomains: []string{"api-prod", "api-staging"}, + expectedPattern: "api{{p0}}", + expectedPayloadLen: 1, + checkPayloads: map[string][]string{ + "p0": {"-prod", "-staging"}, + }, + }, + { + name: "single level with number variation", + subdomains: []string{"api-1", "api-2", "api-3"}, + expectedPattern: "api{{p0}}", + expectedPayloadLen: 1, + checkPayloads: map[string][]string{ + "p0": {"-1", "-2", "-3"}, + }, + }, + { + name: "complex single level", + subdomains: []string{"api-prod-1", "api-prod-2", "api-staging-1"}, + expectedPattern: "api{{p0}}{{p1}}", + expectedPayloadLen: 2, + 
checkPayloads: map[string][]string{ + "p0": {"-prod", "-staging"}, + "p1": {"-1", "-2"}, + }, + }, + { + name: "multi-level simple", + subdomains: []string{"api.dev", "api.prod"}, + expectedPattern: "api.{{p0}}", + expectedPayloadLen: 1, + checkPayloads: map[string][]string{ + "p0": {"dev", "prod"}, + }, + }, + { + name: "multi-level complex", + subdomains: []string{"api-1.dev", "api-2.dev", "api-1.prod"}, + expectedPattern: "api{{p0}}.{{p1}}", + expectedPayloadLen: 2, + checkPayloads: map[string][]string{ + "p0": {"-1", "-2"}, + "p1": {"dev", "prod"}, + }, + }, + { + name: "with numbers", + subdomains: []string{"web01", "web02", "web03"}, + expectedPattern: "web{{p0}}", + expectedPayloadLen: 1, + checkPayloads: map[string][]string{ + "p0": {"01", "02", "03"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a pattern miner instance + domains := make([]string, len(tt.subdomains)) + for i, sub := range tt.subdomains { + domains[i] = sub + ".example.com" + } + + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 10, + }) + require.NoError(t, err, "Failed to create PatternMiner") + + // Generate pattern + pattern, err := pm.generatePattern(tt.subdomains) + require.NoError(t, err, "generatePattern() should not return error") + require.NotNil(t, pattern, "generatePattern() should not return nil pattern") + + // Check pattern string + assert.Equal(t, tt.expectedPattern, pattern.Pattern, "Pattern mismatch") + + // Check payload count + assert.Len(t, pattern.Payloads, tt.expectedPayloadLen, "Payload count mismatch") + + // Check specific payloads if provided + if tt.checkPayloads != nil { + for varName, expectedValues := range tt.checkPayloads { + actualValues, ok := pattern.Payloads[varName] + require.True(t, ok, "Payload %q not found in result", varName) + + // Check if all expected values are present + for _, expectedVal := range expectedValues { + assert.Contains(t, actualValues, expectedVal, "Payload 
%q missing expected value %q", varName, expectedVal) + } + } + } + }) + } +} + +func TestAnalyzeTokenAlignment(t *testing.T) { + tests := []struct { + name string + subdomains []string + expectedLevels int + checkLevel0 func(t *testing.T, lp LevelPosition) + }{ + { + name: "static single level", + subdomains: []string{"api", "api", "api"}, + expectedLevels: 1, + checkLevel0: func(t *testing.T, lp LevelPosition) { + assert.Len(t, lp.Positions, 1, "Expected 1 position") + assert.Equal(t, TokenPositionStatic, lp.Positions[0].Type, "Expected static token") + }, + }, + { + name: "variable single level", + subdomains: []string{"api-prod", "api-staging"}, + expectedLevels: 1, + checkLevel0: func(t *testing.T, lp LevelPosition) { + assert.Len(t, lp.Positions, 2, "Expected 2 positions") + // First token "api" should be static + assert.Equal(t, TokenPositionStatic, lp.Positions[0].Type, "Expected first token to be static") + // Second token should be variable + assert.Equal(t, TokenPositionVariable, lp.Positions[1].Type, "Expected second token to be variable") + }, + }, + { + name: "multi-level", + subdomains: []string{"api.dev", "api.prod"}, + expectedLevels: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create pattern miner + domains := make([]string, len(tt.subdomains)) + for i, sub := range tt.subdomains { + domains[i] = sub + ".example.com" + } + + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 10, + }) + require.NoError(t, err, "Failed to create PatternMiner") + + // Tokenize + tokenized := Tokenize(tt.subdomains) + + // Analyze alignment + positions := pm.analyzeTokenAlignment(tokenized) + + // Check level count + assert.Len(t, positions, tt.expectedLevels, "Level count mismatch") + + // Run custom checks if provided + if tt.checkLevel0 != nil && len(positions) > 0 { + tt.checkLevel0(t, positions[0]) + } + }) + } +} + +func TestBuildDSLPattern(t *testing.T) { + tests := []struct { + name string + 
levelPositions []LevelPosition + expectedPattern string + }{ + { + name: "single static level", + levelPositions: []LevelPosition{ + { + LevelIndex: 0, + Positions: []TokenPosition{ + {Index: 0, Type: TokenPositionStatic, Values: []string{"api"}}, + }, + }, + }, + expectedPattern: "api", + }, + { + name: "single level with variable", + levelPositions: []LevelPosition{ + { + LevelIndex: 0, + Positions: []TokenPosition{ + {Index: 0, Type: TokenPositionStatic, Values: []string{"api"}}, + {Index: 1, Type: TokenPositionVariable, VarName: "p0", Values: []string{"-prod", "-staging"}}, + }, + }, + }, + expectedPattern: "api{{p0}}", + }, + { + name: "multi-level pattern", + levelPositions: []LevelPosition{ + { + LevelIndex: 0, + Positions: []TokenPosition{ + {Index: 0, Type: TokenPositionStatic, Values: []string{"api"}}, + }, + }, + { + LevelIndex: 1, + Positions: []TokenPosition{ + {Index: 0, Type: TokenPositionVariable, VarName: "p0", Values: []string{"dev", "prod"}}, + }, + }, + }, + expectedPattern: "api.{{p0}}", + }, + { + name: "complex pattern", + levelPositions: []LevelPosition{ + { + LevelIndex: 0, + Positions: []TokenPosition{ + {Index: 0, Type: TokenPositionStatic, Values: []string{"api"}}, + {Index: 1, Type: TokenPositionVariable, VarName: "p0", Values: []string{"-1", "-2"}}, + }, + }, + { + LevelIndex: 1, + Positions: []TokenPosition{ + {Index: 0, Type: TokenPositionVariable, VarName: "p1", Values: []string{"dev", "prod"}}, + }, + }, + }, + expectedPattern: "api{{p0}}.{{p1}}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a dummy pattern miner + pm := &PatternMiner{} + + pattern := pm.buildDSLPattern(tt.levelPositions) + + assert.Equal(t, tt.expectedPattern, pattern, "Pattern mismatch") + }) + } +} + +func TestExtractPayloads(t *testing.T) { + tests := []struct { + name string + levelPositions []LevelPosition + subdomains []string + expectedPayloads map[string]int // varName -> count of unique values + checkContains 
map[string][]string // varName -> values that must be present + }{ + { + name: "single variable", + levelPositions: []LevelPosition{ + { + LevelIndex: 0, + Positions: []TokenPosition{ + {Index: 0, Type: TokenPositionStatic, Values: []string{"api"}}, + {Index: 1, Type: TokenPositionVariable, VarName: "p0", Values: []string{"-prod", "-staging"}}, + }, + }, + }, + subdomains: []string{"api-prod", "api-staging"}, + expectedPayloads: map[string]int{ + "p0": 2, + }, + checkContains: map[string][]string{ + "p0": {"-prod", "-staging"}, + }, + }, + { + name: "multiple variables", + levelPositions: []LevelPosition{ + { + LevelIndex: 0, + Positions: []TokenPosition{ + {Index: 0, Type: TokenPositionStatic, Values: []string{"api"}}, + {Index: 1, Type: TokenPositionVariable, VarName: "p0", Values: []string{"-prod", "-staging"}}, + {Index: 2, Type: TokenPositionVariable, VarName: "p1", Values: []string{"-1", "-2"}}, + }, + }, + }, + subdomains: []string{"api-prod-1", "api-staging-2"}, + expectedPayloads: map[string]int{ + "p0": 2, + "p1": 2, + }, + checkContains: map[string][]string{ + "p0": {"-prod", "-staging"}, + "p1": {"-1", "-2"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm := &PatternMiner{} + tokenized := Tokenize(tt.subdomains) + + payloads := pm.extractPayloads(tt.levelPositions, tokenized) + + // Check payload count + assert.Len(t, payloads, len(tt.expectedPayloads), "Payload count mismatch") + + // Check each payload + for varName, expectedCount := range tt.expectedPayloads { + values, ok := payloads[varName] + require.True(t, ok, "Payload %q not found", varName) + assert.Len(t, values, expectedCount, "Payload %q value count mismatch", varName) + } + + // Check specific values + if tt.checkContains != nil { + for varName, expectedValues := range tt.checkContains { + actualValues, ok := payloads[varName] + require.True(t, ok, "Payload %q not found", varName) + + for _, expectedVal := range expectedValues { + assert.Contains(t, 
actualValues, expectedVal, "Payload %q missing value %q", varName, expectedVal) + } + } + } + }) + } +} + +func TestCalculateCombinations(t *testing.T) { + tests := []struct { + name string + pattern *DSLPattern + expected int + }{ + { + name: "static pattern", + pattern: &DSLPattern{ + Pattern: "api", + Payloads: map[string][]string{}, + }, + expected: 1, + }, + { + name: "single variable", + pattern: &DSLPattern{ + Pattern: "api{{p0}}", + Payloads: map[string][]string{ + "p0": {"-prod", "-staging"}, + }, + }, + expected: 2, + }, + { + name: "two variables", + pattern: &DSLPattern{ + Pattern: "api{{p0}}.{{p1}}", + Payloads: map[string][]string{ + "p0": {"-prod", "-staging", "-dev"}, + "p1": {"us", "eu"}, + }, + }, + expected: 6, // 3 × 2 + }, + { + name: "three variables", + pattern: &DSLPattern{ + Pattern: "{{p0}}{{p1}}.{{p2}}", + Payloads: map[string][]string{ + "p0": {"api", "web"}, + "p1": {"-1", "-2"}, + "p2": {"dev", "prod", "staging"}, + }, + }, + expected: 12, // 2 × 2 × 3 + }, + { + name: "optional position with empty string", + pattern: &DSLPattern{ + Pattern: "api{{p0}}", + Payloads: map[string][]string{ + "p0": {"-prod", ""}, + }, + }, + expected: 2, // Generates: api-prod, api + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm := &PatternMiner{} + result := pm.calculateCombinations(tt.pattern) + assert.Equal(t, tt.expected, result, "Combination count mismatch") + }) + } +} + +func TestIsGoodPattern(t *testing.T) { + tests := []struct { + name string + pattern *DSLPattern + nkeys int + threshold float64 + maxRatio float64 + expected bool + reason string + }{ + { + name: "passes absolute threshold", + pattern: &DSLPattern{ + Pattern: "api{{p0}}", + Payloads: map[string][]string{ + "p0": {"-prod", "-staging"}, + }, + }, + nkeys: 2, + threshold: 100, + maxRatio: 10, + expected: true, + reason: "2 combinations < 100 threshold", + }, + { + name: "passes ratio check", + pattern: &DSLPattern{ + Pattern: "api{{p0}}.{{p1}}", + 
Payloads: map[string][]string{ + "p0": {"-prod", "-staging"}, + "p1": {"dev", "staging"}, + }, + }, + nkeys: 2, + threshold: 2, // 4 combinations > 2 threshold + maxRatio: 5, // but ratio 4/2 = 2.0 < 5 + expected: true, + reason: "ratio 2.0 < 5.0 max_ratio", + }, + { + name: "fails both checks - too generic", + pattern: &DSLPattern{ + Pattern: "{{p0}}{{p1}}.{{p2}}", + Payloads: map[string][]string{ + "p0": {"api", "web", "app"}, + "p1": {"-1", "-2", "-3", "-4"}, + "p2": {"dev", "prod", "staging"}, + }, + }, + nkeys: 3, + threshold: 10, // 36 combinations > 10 threshold + maxRatio: 5, // ratio 36/3 = 12.0 > 5 max_ratio + expected: false, + reason: "36 combinations exceeds threshold and ratio 12.0 exceeds max_ratio", + }, + { + name: "static pattern passes", + pattern: &DSLPattern{ + Pattern: "api", + Payloads: map[string][]string{}, + }, + nkeys: 1, + threshold: 2, // 1 < 2, passes + maxRatio: 1, + expected: true, + reason: "1 combination < 2 threshold", + }, + { + name: "no thresholds configured - accepts all", + pattern: &DSLPattern{ + Pattern: "{{p0}}{{p1}}{{p2}}", + Payloads: map[string][]string{ + "p0": {"1", "2", "3", "4", "5"}, + "p1": {"a", "b", "c"}, + "p2": {"x", "y"}, + }, + }, + nkeys: 2, + threshold: 0, // disabled + maxRatio: 0, // disabled + expected: true, + reason: "no thresholds configured", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm := &PatternMiner{ + options: &Options{ + PatternThreshold: tt.threshold, + PatternQualityRatio: tt.maxRatio, + }, + } + + result := pm.isGoodPattern(tt.pattern, tt.nkeys) + assert.Equal(t, tt.expected, result, tt.reason) + }) + } +} + +func TestGeneratePatternWithQualityCheck(t *testing.T) { + tests := []struct { + name string + subdomains []string + threshold float64 + maxRatio float64 + expectNil bool + reason string + }{ + { + name: "good pattern - accepted", + subdomains: []string{"api-prod", "api-staging"}, + threshold: 100, + maxRatio: 10, + expectNil: false, + reason: 
"pattern generates 2 combinations, well within limits", + }, + { + name: "too generic - rejected", + subdomains: []string{"a", "b", "c"}, + threshold: 1, // very strict + maxRatio: 0.5, // very strict ratio + expectNil: true, + reason: "pattern would be too generic and gets rejected", + }, + { + name: "no thresholds - always accepts", + subdomains: []string{"api-1", "api-2", "api-3"}, + threshold: 0, // disabled + maxRatio: 0, // disabled + expectNil: false, + reason: "no quality checks when thresholds disabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + domains := make([]string, len(tt.subdomains)) + for i, sub := range tt.subdomains { + domains[i] = sub + ".example.com" + } + + pm, err := NewPatternMiner(domains, &Options{ + MinLDist: 2, + MaxLDist: 10, + PatternThreshold: tt.threshold, + PatternQualityRatio: tt.maxRatio, + }) + require.NoError(t, err) + + pattern, err := pm.generatePattern(tt.subdomains) + require.NoError(t, err) + + if tt.expectNil { + assert.Nil(t, pattern, tt.reason) + } else { + assert.NotNil(t, pattern, tt.reason) + } + }) + } +} diff --git a/mining/pm.go b/mining/pm.go new file mode 100644 index 00000000..e247bc74 --- /dev/null +++ b/mining/pm.go @@ -0,0 +1,173 @@ +package mining + +/// Jargons & Definitions +/// for api5.dev.example.com +// subdomain/subdomainpart = api5.dev +// root/ root domain = example.com +// level0 = api5 +// level1 = dev +// for level 0 , section1 = api , section2 = 5 + +import ( + "strings" + + "github.com/armon/go-radix" + levenshtein "github.com/ka-weihe/fast-levenshtein" + "github.com/projectdiscovery/utils/errkit" + mapsutil "github.com/projectdiscovery/utils/maps" + "golang.org/x/net/publicsuffix" +) + +var ( + ErrNoDomains = errkit.New("no domains provided to mine patterns") +) + +type Options struct { + // MinLDist is the minimum levenshtein distance for clustering + MinLDist int + // MaxLDist is the maximum levenshtein distance for clustering + MaxLDist int + // 
PatternThreshold is the threshold after which pattern will be discarded + // because of being too generic + PatternThreshold float64 + // PatternQualityRatio is the ratio of output/input patterns + // after generating patterns from a cluster it is used to discard low quality patterns + // whose ratio is greater than this threshold + PatternQualityRatio float64 + // MaxPatternLength is the maximum length of generated pattern string + // patterns exceeding this length are discarded + MaxPatternLength int + // NgramsLimit limits the number of ngrams processed in hierarchical clustering + // If 0, all ngrams are processed. This matches Python's ngrams_limit parameter. + NgramsLimit int +} + +func (o *Options) applyDefaults() { + // reference from regulator + if o.MinLDist == 0 { + o.MinLDist = 2 + } + if o.MaxLDist == 0 { + o.MaxLDist = 10 + } +} + +// PatternMiner is the main struct for pattern mining +// it mines for patterns for the given domains +type PatternMiner struct { + rootDomains []string + subdomains []string + trie *radix.Tree // radix tree for fast prefix searches + distanceMap map[Edge]int // contains distance betwen two nodes or items + options *Options + results []*DSLPattern // collected patterns that passed quality checks + seenPatterns map[string]struct{} // deduplication: tracks pattern strings already generated +} + +// NewPatternMiner creates a new pattern miner instance +func NewPatternMiner(domains []string, opts *Options) (*PatternMiner, error) { + if len(domains) == 0 { + return nil, ErrNoDomains + } + opts.applyDefaults() + p := &PatternMiner{ + distanceMap: make(map[Edge]int), + options: opts, + trie: radix.New(), + results: make([]*DSLPattern, 0), + seenPatterns: make(map[string]struct{}), + } + if err := p.prepare(domains); err != nil { + return nil, err + } + return p, nil +} + +func (p *PatternMiner) prepare(domains []string) error { + var subs = make(map[string]struct{}) + var rootDomains = make(map[string]struct{}) + for _, domain := 
range domains { + rootDomain, err := publicsuffix.EffectiveTLDPlusOne(domain) + if err != nil { + return err + } + if _, ok := rootDomains[rootDomain]; !ok { + rootDomains[rootDomain] = struct{}{} + } + sub := strings.TrimSuffix(domain, "."+rootDomain) + if _, ok := subs[sub]; !ok { + subs[sub] = struct{}{} + } + } + p.rootDomains = mapsutil.GetKeys(rootDomains) + p.subdomains = mapsutil.GetKeys(subs) + + // build radix tree for fast prefix searches + // this tree is used to do fast lookup of all subdomains with a given prefix + // ex: prefix "ap" will return api, api1, app, etc. + for k := range subs { + p.trie.Insert(k, nil) // value is nil since we only need to track keys + } + + // distance map + // calculate levenshtein distance between all subdomains + // ex: distance between api and api1 is 1 + // while distance between api and apple is 3 + for _, x := range p.subdomains { + for _, y := range p.subdomains { + if x == y { + continue + } + // get a predicatable edgename between subdomains + edge := NewEdge(x, y) + if _, ok := p.distanceMap[edge]; !ok { + p.distanceMap[edge] = levenshtein.Distance(x, y) + } + } + } + return nil +} + +// GetResults returns all patterns that were generated and passed quality checks. +// This should be called after Execute() completes. 
+func (p *PatternMiner) GetResults() []*DSLPattern { + return p.results +} + +// Execute mines for patterns from all existing data +func (p *PatternMiner) Execute() error { + // The core idea of the algorithm is to group or cluster subdomains + // into a set of unique subdomains that might be related in some way + // for each such group execute GeneratePattern() method to generate patterns + // from that group + + // to generate high quality patterns we cluster subdomains using many + // clustering approaches and generate patterns from each group and combine them + // when generating pattern, we purge low quality patterns by using a ratio of input/output patterns + + // Approaches used + + // 1) Levenshtein Distance on Subdomain Part Clustering + if err := p.levenshteinSubsClustering(); err != nil { + return err + } + + // 2) Hierarchical Ngram-Based Clustering + // This approach uses a multi-level hierarchy that combines: + // - Unigram/Bigram Prefix Clustering + // - Full Token Prefix Matching (extract first token, then cluster by that prefix) + // - Levenshtein Distance on Prefixes Clustering (edit distance on prefix-matched subsets) + // + // Flow: ngram → keys → generate pattern (Chance 1) + // → extract prefixes → for each prefix: + // → get keys → generate pattern (Chance 2) + // → edit distance clustering → patterns (Chance 3) + // + // This hierarchical approach provides multiple chances to generate patterns at different + // levels of granularity, resulting in comprehensive pattern mining. 
+ if err := p.hierarchicalNgramClustering(); err != nil { + return err + } + + return nil +} diff --git a/mining/pm_test.go b/mining/pm_test.go new file mode 100644 index 00000000..01e4e87d --- /dev/null +++ b/mining/pm_test.go @@ -0,0 +1,299 @@ +package mining + +import ( + "sort" + "testing" + + levenshtein "github.com/ka-weihe/fast-levenshtein" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper function to create a PatternMiner with test data +func createTestPatternMiner(subdomains []string) *PatternMiner { + p := &PatternMiner{ + subdomains: subdomains, + distanceMap: make(map[Edge]int), + options: &Options{ + MinLDist: 2, + MaxLDist: 10, + }, + } + + // Calculate levenshtein distance between all subdomains + for _, x := range p.subdomains { + for _, y := range p.subdomains { + if x == y { + continue + } + edge := NewEdge(x, y) + if _, ok := p.distanceMap[edge]; !ok { + p.distanceMap[edge] = levenshtein.Distance(x, y) + } + } + } + + return p +} + +// Helper function to sort clusters for consistent comparison +func sortClusters(clusters [][]string) { + for _, cluster := range clusters { + sort.Strings(cluster) + } + sort.Slice(clusters, func(i, j int) bool { + if len(clusters[i]) != len(clusters[j]) { + return len(clusters[i]) > len(clusters[j]) + } + return clusters[i][0] < clusters[j][0] + }) +} + +// Helper function to check if two cluster sets are equal +func clustersEqual(a, b [][]string) bool { + if len(a) != len(b) { + return false + } + sortClusters(a) + sortClusters(b) + + for i := range a { + if len(a[i]) != len(b[i]) { + return false + } + for j := range a[i] { + if a[i][j] != b[i][j] { + return false + } + } + } + return true +} + +// Test empty input +func TestGetClustersByLevenshteinDistance_Empty(t *testing.T) { + p := createTestPatternMiner([]string{}) + clusters, err := p.getClustersByLevenshteinDistance(1) + + require.NoError(t, err) + assert.Nil(t, clusters) +} + +// Test single subdomain - should 
return empty (no non-singleton clusters) +func TestGetClustersByLevenshteinDistance_SingleSubdomain(t *testing.T) { + p := createTestPatternMiner([]string{"api"}) + clusters, err := p.getClustersByLevenshteinDistance(1) + + require.NoError(t, err) + assert.Empty(t, clusters, "Single subdomain should not create any clusters") +} + +// Test subdomains that are far apart - no clusters should form +func TestGetClustersByLevenshteinDistance_NoSimilarSubdomains(t *testing.T) { + p := createTestPatternMiner([]string{"api", "website", "dashboard"}) + + // Verify distances are > 1 + assert.Greater(t, p.distanceMap[NewEdge("api", "website")], 1) + assert.Greater(t, p.distanceMap[NewEdge("api", "dashboard")], 1) + + clusters, err := p.getClustersByLevenshteinDistance(1) + + require.NoError(t, err) + assert.Empty(t, clusters, "Subdomains with distance >= k should not cluster") +} + +// Test simple pair with distance 1 +func TestGetClustersByLevenshteinDistance_SimplePair(t *testing.T) { + // "api" and "api1" have distance 1 + p := createTestPatternMiner([]string{"api", "api1"}) + + // Verify distance + assert.Equal(t, 1, p.distanceMap[NewEdge("api", "api1")]) + + clusters, err := p.getClustersByLevenshteinDistance(2) // k=2, so dist 1 < 2 + + require.NoError(t, err) + require.Len(t, clusters, 1, "Should create exactly one cluster") + + expected := [][]string{{"api", "api1"}} + assert.True(t, clustersEqual(clusters, expected)) +} + +// Test distance boundary - dist < k (not <=) +func TestGetClustersByLevenshteinDistance_StrictLessThan(t *testing.T) { + // "api" and "api12" have distance 2 + p := createTestPatternMiner([]string{"api", "api12"}) + + // Verify distance + assert.Equal(t, 2, p.distanceMap[NewEdge("api", "api12")]) + + // With k=2, distance 2 is NOT < 2, so should not cluster + clusters, err := p.getClustersByLevenshteinDistance(2) + + require.NoError(t, err) + assert.Empty(t, clusters, "Distance = k should NOT cluster (requires dist < k)") + + // With k=3, distance 
2 < 3, so should cluster + clusters, err = p.getClustersByLevenshteinDistance(3) + + require.NoError(t, err) + require.Len(t, clusters, 1) + expected := [][]string{{"api", "api12"}} + assert.True(t, clustersEqual(clusters, expected)) +} + +// Test NON-transitive behavior - key example from documentation +func TestGetClustersByLevenshteinDistance_NonTransitive(t *testing.T) { + // Example 2 from documentation: + // Items: {A, B, C} where A↔B=1, B↔C=1, A↔C=3 + // With k=2: + // - Center A: {A, B} (B dist 1 < 2, C dist 3 ≮ 2) + // - Center B: {A, B, C} (A dist 1 < 2, C dist 1 < 2) + // - Center C: {B, C} (B dist 1 < 2, A dist 3 ≮ 2) + // Result: [{A, B, C}, {A, B}, {B, C}] + + // Using "api", "api1", "api12" to match this pattern + p := createTestPatternMiner([]string{"api", "api1", "api12"}) + + // Verify distances match the pattern + assert.Equal(t, 1, p.distanceMap[NewEdge("api", "api1")]) + assert.Equal(t, 1, p.distanceMap[NewEdge("api1", "api12")]) + assert.Equal(t, 2, p.distanceMap[NewEdge("api", "api12")]) + + clusters, err := p.getClustersByLevenshteinDistance(2) + + require.NoError(t, err) + + // Should get 3 distinct clusters + require.Len(t, clusters, 3, "Should create three distinct clusters") + + expected := [][]string{ + {"api", "api1", "api12"}, // Center: api1 + {"api", "api1"}, // Center: api + {"api1", "api12"}, // Center: api12 + } + assert.True(t, clustersEqual(clusters, expected)) +} + +// Test multiple separate clusters +func TestGetClustersByLevenshteinDistance_MultipleClusters(t *testing.T) { + // Two separate cluster groups (need to use more distant subdomains) + // Cluster group 1: api, api1, api2 + // Cluster group 2: web, web1 + subdomains := []string{"api", "api1", "api2", "web", "web1"} + p := createTestPatternMiner(subdomains) + + // Verify api and web are far apart + assert.Greater(t, p.distanceMap[NewEdge("api", "web")], 2) + + clusters, err := p.getClustersByLevenshteinDistance(2) + + require.NoError(t, err) + require.Greater(t, 
len(clusters), 0, "Should create at least one cluster") + + // All clusters should have more than 1 item + for _, cluster := range clusters { + assert.Greater(t, len(cluster), 1, "Each cluster should have more than 1 item") + } +} + +// Test deduplication behavior +func TestGetClustersByLevenshteinDistance_Deduplication(t *testing.T) { + // "api" and "api1" both have dist 1 + // Both will generate the same cluster {api, api1} + // Should deduplicate to just one cluster + p := createTestPatternMiner([]string{"api", "api1"}) + + clusters, err := p.getClustersByLevenshteinDistance(2) + + require.NoError(t, err) + require.Len(t, clusters, 1, "Should deduplicate identical clusters") + + expected := [][]string{{"api", "api1"}} + assert.True(t, clustersEqual(clusters, expected)) +} + +// Test real-world subdomain patterns +func TestGetClustersByLevenshteinDistance_RealWorld(t *testing.T) { + subdomains := []string{ + "api", "api-v1", "api-v2", + "staging", "staging-api", "staging-web", + "prod", "prod-api", "prod-web", + "dev", "dev-api", + } + p := createTestPatternMiner(subdomains) + + // Test with different thresholds + for k := 3; k <= 5; k++ { + clusters, err := p.getClustersByLevenshteinDistance(k) + require.NoError(t, err) + + t.Logf("k=%d: found %d clusters", k, len(clusters)) + for i, cluster := range clusters { + t.Logf(" Cluster %d (size=%d): %v", i, len(cluster), cluster) + assert.GreaterOrEqual(t, len(cluster), 2, "Each cluster should have at least 2 items") + } + } +} + +// Test overlapping clusters example from documentation +func TestGetClustersByLevenshteinDistance_OverlappingExample(t *testing.T) { + // Example 3 from documentation (simplified): + // Need items where we get overlapping clusters + // Using "aa", "aaa", "aaaa" to get progressive distances + subdomains := []string{"aa", "aaa", "aaaa", "aaaaa"} + p := createTestPatternMiner(subdomains) + + // Print distances for verification + for i := 0; i < len(subdomains); i++ { + for j := i + 1; j < 
len(subdomains); j++ { + edge := NewEdge(subdomains[i], subdomains[j]) + t.Logf("Distance %s ↔ %s = %d", subdomains[i], subdomains[j], p.distanceMap[edge]) + } + } + + clusters, err := p.getClustersByLevenshteinDistance(2) + + require.NoError(t, err) + t.Logf("Found %d clusters:", len(clusters)) + for i, cluster := range clusters { + t.Logf(" Cluster %d: %v", i, cluster) + } + + // Each cluster should have at least 2 items + for _, cluster := range clusters { + assert.GreaterOrEqual(t, len(cluster), 2) + } +} + +// Test Edge creation consistency +func TestGetClustersByLevenshteinDistance_EdgeConsistency(t *testing.T) { + subdomains := []string{"abc", "abd", "xyz"} + p := createTestPatternMiner(subdomains) + + // Verify Edge creates consistent keys regardless of order + edge1 := NewEdge("abc", "abd") + edge2 := NewEdge("abd", "abc") + assert.Equal(t, edge1, edge2, "Edge should be order-independent") + + // Verify distance is stored correctly + dist, ok := p.distanceMap[edge1] + assert.True(t, ok, "Distance should be stored for edge") + assert.Equal(t, 1, dist, "Distance between 'abc' and 'abd' should be 1") +} + +// Benchmark the clustering algorithm +func BenchmarkGetClustersByLevenshteinDistance(b *testing.B) { + subdomains := []string{ + "api", "api1", "api2", "api3", "api-v1", "api-v2", + "web", "web1", "web2", "webapp", "website", + "app", "app1", "app2", "mobile", "mobile-app", + "dev", "dev-api", "staging", "staging-api", "prod", "prod-api", + } + p := createTestPatternMiner(subdomains) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = p.getClustersByLevenshteinDistance(3) + } +} diff --git a/mining/tokenization.go b/mining/tokenization.go new file mode 100644 index 00000000..92df68d5 --- /dev/null +++ b/mining/tokenization.go @@ -0,0 +1,214 @@ +package mining + +import ( + "regexp" + "strings" +) + +// TokenizedSubdomain represents a tokenized subdomain with hierarchical structure. 
+type TokenizedSubdomain struct { + // Original is the original subdomain string that was tokenized + Original string + // Levels contains the tokenized levels of the subdomain hierarchy + Levels []Level +} + +// Level represents a single level in the subdomain hierarchy with its tokens. +// For example, in "api-prod-12.dev", there are two levels: +// - Level 0: {Label: "api-prod-12", Tokens: ["api", "-prod", "-12"]} +// - Level 1: {Label: "dev", Tokens: ["dev"]} +type Level struct { + // Label is the original label at this level (e.g., "api-prod-12") + Label string + // Tokens are the individual tokens extracted from the label + Tokens []string +} + +// Tokenize converts subdomains into structured tokenized representations. +// It splits subdomains by dots into hierarchical levels and tokenizes each level +// by hyphens and numbers while preserving hyphen prefixes. +// +// NOTE: Input should be subdomain parts only (root domain already removed). +// For example: "api-prod-12" or "api.dev" (not "api.dev.example.com") +// +// EXAMPLE: +// +// Input: ["api-prod-12", "web", "api5.dev-staging2"] +// Output: []TokenizedSubdomain{ +// { +// Original: "api-prod-12", +// Levels: []Level{ +// {Label: "api-prod-12", Tokens: []string{"api", "-prod", "-12"}}, +// }, +// }, +// { +// Original: "web", +// Levels: []Level{ +// {Label: "web", Tokens: []string{"web"}}, +// }, +// }, +// { +// Original: "api5.dev-staging2", +// Levels: []Level{ +// {Label: "api5", Tokens: []string{"api", "5"}}, +// {Label: "dev-staging2", Tokens: []string{"dev", "-staging", "2"}}, +// }, +// }, +// } +// +// ALGORITHM: +// 1. Split subdomain by '.' to get hierarchical levels +// 2. For each level, tokenize by hyphens and numbers: +// - Split by '-' and prefix subsequent parts with '-' +// - Further split by numbers using regex +// - Special case: merge standalone '-' with following numbers +// +// This preserves the structure needed for pattern mining and clustering. 
+func Tokenize(subdomains []string) []TokenizedSubdomain { + result := make([]TokenizedSubdomain, 0, len(subdomains)) + + for _, subdomain := range subdomains { + tokenized := TokenizedSubdomain{ + Original: subdomain, + Levels: []Level{}, // Initialize to empty slice, not nil + } + + // Handle empty subdomains + if subdomain == "" { + result = append(result, tokenized) + continue + } + + // Split subdomain by '.' to get hierarchical labels + labels := strings.Split(subdomain, ".") + tokenized.Levels = make([]Level, 0, len(labels)) + + for _, label := range labels { + if label == "" { + continue + } + level := Level{ + Label: label, + Tokens: tokenizeLabel(label), + } + tokenized.Levels = append(tokenized.Levels, level) + } + + result = append(result, tokenized) + } + + return result +} + +// tokenizeLabel tokenizes a single label by splitting on hyphens and numbers. +// +// ALGORITHM: +// 1. Split by '-' and prefix subsequent parts with '-' +// 2. Split each part by numbers (e.g., "api12" → ["api", "12"]) +// 3. 
Handle special case: standalone '-' followed by number becomes '-number'
+//
+// EXAMPLE:
+//
+//	"api-prod-12" → ["api", "-prod", "-12"]
+//	"web01" → ["web", "01"]
+//	"foo-12" → ["foo", "-12"]
+func tokenizeLabel(label string) []string {
+	tokens := make([]string, 0)
+
+	// Split by hyphens and prefix subsequent parts with '-'
+	hyphenParts := strings.Split(label, "-")
+	for i, part := range hyphenParts {
+		if part == "" {
+			continue
+		}
+
+		// Prefix with '-' for all parts except the first
+		if i != 0 {
+			part = "-" + part
+		}
+
+		// Split by numbers using regex
+		subtokens := splitByNumbers(part)
+
+		// Handle special case: merge standalone '-' with following number
+		// This happens when we have patterns like "foo-12"
+		filtered := make([]string, 0, len(subtokens))
+		for j, subtoken := range subtokens {
+			if subtoken == "-" && j+1 < len(subtokens) {
+				// a bare '-' is always followed by a numeric token here
+				// (see splitByNumbers), fold it into that token
+				subtokens[j+1] = "-" + subtokens[j+1]
+			} else {
+				filtered = append(filtered, subtoken)
+			}
+		}
+
+		tokens = append(tokens, filtered...)
+	}
+
+	return tokens
+}
+
+// numberSplitRegex is used to split strings by numeric sequences
+var numberSplitRegex = regexp.MustCompile(`([0-9]+)`)
+
+// splitByNumbers splits a string by numeric sequences while keeping the numbers. 
+//
+// EXAMPLE:
+//
+//	"api12web34" → ["api", "12", "web", "34"]
+//	"prod" → ["prod"]
+//	"123" → ["123"]
+func splitByNumbers(s string) []string {
+	// Split around the numeric matches, then recover the matches
+	// themselves so they can be woven back into the output.
+	parts := numberSplitRegex.Split(s, -1)
+	numbers := numberSplitRegex.FindAllString(s, -1)
+
+	result := make([]string, 0, len(parts)+len(numbers))
+
+	// Interleave parts and numbers. Split always yields exactly
+	// len(numbers)+1 parts, so the i-th separator follows the i-th part.
+	for i, part := range parts {
+		if part != "" {
+			result = append(result, part)
+		}
+		if i < len(numbers) {
+			result = append(result, numbers[i])
+		}
+	}
+
+	return result
+}
+
+// extractFirstToken extracts the first token from a subdomain string.
+// This is used for prefix-based clustering in the pattern mining algorithm.
+//
+// EXAMPLE:
+//
+//	"api-prod-1" → "api"
+//	"web.dev" → "web"
+//	"api5" → "api"
+func (p *PatternMiner) extractFirstToken(subdomain string) string {
+	if subdomain == "" {
+		return ""
+	}
+
+	// Split by '.' 
to get the first level + parts := strings.Split(subdomain, ".") + if len(parts) == 0 { + return "" + } + + // Tokenize the first level + tokens := tokenizeLabel(parts[0]) + if len(tokens) == 0 { + return "" + } + + // Return the first token, removing any hyphen prefix + firstToken := tokens[0] + return strings.TrimPrefix(firstToken, "-") +} diff --git a/mining/tokenization_test.go b/mining/tokenization_test.go new file mode 100644 index 00000000..47c112b5 --- /dev/null +++ b/mining/tokenization_test.go @@ -0,0 +1,474 @@ +package mining + +import ( + "reflect" + "testing" +) + +func TestTokenize(t *testing.T) { + tests := []struct { + name string + input []string + expected []TokenizedSubdomain + }{ + { + name: "simple subdomain with hyphen and number", + input: []string{"api-prod-12"}, + expected: []TokenizedSubdomain{ + { + Original: "api-prod-12", + Levels: []Level{ + {Label: "api-prod-12", Tokens: []string{"api", "-prod", "-12"}}, + }, + }, + }, + }, + { + name: "single word subdomain", + input: []string{"web"}, + expected: []TokenizedSubdomain{ + { + Original: "web", + Levels: []Level{ + {Label: "web", Tokens: []string{"web"}}, + }, + }, + }, + }, + { + name: "multi-level subdomain", + input: []string{"api.dev"}, + expected: []TokenizedSubdomain{ + { + Original: "api.dev", + Levels: []Level{ + {Label: "api", Tokens: []string{"api"}}, + {Label: "dev", Tokens: []string{"dev"}}, + }, + }, + }, + }, + { + name: "hyphenated number", + input: []string{"foo-12"}, + expected: []TokenizedSubdomain{ + { + Original: "foo-12", + Levels: []Level{ + {Label: "foo-12", Tokens: []string{"foo", "-12"}}, + }, + }, + }, + }, + { + name: "alphanumeric without hyphen", + input: []string{"web01"}, + expected: []TokenizedSubdomain{ + { + Original: "web01", + Levels: []Level{ + {Label: "web01", Tokens: []string{"web", "01"}}, + }, + }, + }, + }, + { + name: "complex subdomain with numbers", + input: []string{"api5-dev-staging2"}, + expected: []TokenizedSubdomain{ + { + Original: 
"api5-dev-staging2", + Levels: []Level{ + {Label: "api5-dev-staging2", Tokens: []string{"api", "5", "-dev", "-staging", "2"}}, + }, + }, + }, + }, + { + name: "multiple subdomains", + input: []string{"api", "web"}, + expected: []TokenizedSubdomain{ + { + Original: "api", + Levels: []Level{ + {Label: "api", Tokens: []string{"api"}}, + }, + }, + { + Original: "web", + Levels: []Level{ + {Label: "web", Tokens: []string{"web"}}, + }, + }, + }, + }, + { + name: "empty subdomain", + input: []string{""}, + expected: []TokenizedSubdomain{ + { + Original: "", + Levels: []Level{}, + }, + }, + }, + { + name: "multiple hyphens", + input: []string{"api-v1-prod-us-west"}, + expected: []TokenizedSubdomain{ + { + Original: "api-v1-prod-us-west", + Levels: []Level{ + {Label: "api-v1-prod-us-west", Tokens: []string{"api", "-v", "1", "-prod", "-us", "-west"}}, + }, + }, + }, + }, + { + name: "numbers at start", + input: []string{"123api"}, + expected: []TokenizedSubdomain{ + { + Original: "123api", + Levels: []Level{ + {Label: "123api", Tokens: []string{"123", "api"}}, + }, + }, + }, + }, + { + name: "multi-level with complex tokens", + input: []string{"api5.dev-staging2"}, + expected: []TokenizedSubdomain{ + { + Original: "api5.dev-staging2", + Levels: []Level{ + {Label: "api5", Tokens: []string{"api", "5"}}, + {Label: "dev-staging2", Tokens: []string{"dev", "-staging", "2"}}, + }, + }, + }, + }, + { + name: "consecutive numbers", + input: []string{"api123456"}, + expected: []TokenizedSubdomain{ + { + Original: "api123456", + Levels: []Level{ + {Label: "api123456", Tokens: []string{"api", "123456"}}, + }, + }, + }, + }, + { + name: "hyphen at multiple positions", + input: []string{"prod-web-api-v2"}, + expected: []TokenizedSubdomain{ + { + Original: "prod-web-api-v2", + Levels: []Level{ + {Label: "prod-web-api-v2", Tokens: []string{"prod", "-web", "-api", "-v", "2"}}, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := 
Tokenize(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("Tokenize() = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestTokenizeLabel(t *testing.T) { + tests := []struct { + name string + label string + expected []string + }{ + { + name: "simple word", + label: "api", + expected: []string{"api"}, + }, + { + name: "word with number", + label: "api5", + expected: []string{"api", "5"}, + }, + { + name: "hyphenated words", + label: "api-prod", + expected: []string{"api", "-prod"}, + }, + { + name: "hyphenated with number", + label: "api-prod-12", + expected: []string{"api", "-prod", "-12"}, + }, + { + name: "word hyphen number", + label: "foo-12", + expected: []string{"foo", "-12"}, + }, + { + name: "number only", + label: "123", + expected: []string{"123"}, + }, + { + name: "mixed alphanumeric", + label: "web01test99", + expected: []string{"web", "01", "test", "99"}, + }, + { + name: "multiple hyphens", + label: "api-v1-prod", + expected: []string{"api", "-v", "1", "-prod"}, + }, + { + name: "hyphen number hyphen word", + label: "test-123-prod", + expected: []string{"test", "-123", "-prod"}, + }, + { + name: "consecutive numbers", + label: "api123456test", + expected: []string{"api", "123456", "test"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tokenizeLabel(tt.label) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("tokenizeLabel(%q) = %v, want %v", tt.label, result, tt.expected) + } + }) + } +} + +func TestSplitByNumbers(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "word with number", + input: "api12", + expected: []string{"api", "12"}, + }, + { + name: "word only", + input: "prod", + expected: []string{"prod"}, + }, + { + name: "number only", + input: "123", + expected: []string{"123"}, + }, + { + name: "multiple numbers", + input: "api12web34", + expected: []string{"api", "12", "web", "34"}, + }, + { + name: "number 
at start", + input: "123api", + expected: []string{"123", "api"}, + }, + { + name: "hyphen prefix with number", + input: "-prod12", + expected: []string{"-prod", "12"}, + }, + { + name: "consecutive numbers", + input: "test123456", + expected: []string{"test", "123456"}, + }, + { + name: "empty string", + input: "", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := splitByNumbers(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("splitByNumbers(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +// Benchmark tests +func BenchmarkTokenize(b *testing.B) { + input := []string{ + "api-prod-12", + "web01.staging", + "api5-dev-us-west-1", + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + Tokenize(input) + } +} + +func BenchmarkTokenizeLabel(b *testing.B) { + label := "api-prod-12-staging" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tokenizeLabel(label) + } +} + +func BenchmarkSplitByNumbers(b *testing.B) { + input := "api12web34test56" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitByNumbers(input) + } +} + +// Table-driven test for edge cases +func TestTokenizeEdgeCases(t *testing.T) { + tests := []struct { + name string + input []string + expected []TokenizedSubdomain + }{ + { + name: "empty input", + input: []string{}, + expected: []TokenizedSubdomain{}, + }, + { + name: "empty string in array", + input: []string{""}, + expected: []TokenizedSubdomain{ + { + Original: "", + Levels: []Level{}, + }, + }, + }, + { + name: "multiple dots in subdomain", + input: []string{"a.b.c.d"}, + expected: []TokenizedSubdomain{ + { + Original: "a.b.c.d", + Levels: []Level{ + {Label: "a", Tokens: []string{"a"}}, + {Label: "b", Tokens: []string{"b"}}, + {Label: "c", Tokens: []string{"c"}}, + {Label: "d", Tokens: []string{"d"}}, + }, + }, + }, + }, + { + name: "special characters with hyphens", + input: []string{"api-v2-beta-3"}, + expected: []TokenizedSubdomain{ + { + 
// Edge represents a connection between two nodes or items —
// in our case a connection between two subdomains. The pair is stored in
// lexicographic order so Edge{a, b} and Edge{b, a} compare equal.
type Edge [2]string

// NewEdge returns the canonical (lexicographically ordered) Edge for the
// given pair of subdomains.
func NewEdge(sub1, sub2 string) Edge {
	if sub1 > sub2 {
		return Edge{sub2, sub1}
	}
	return Edge{sub1, sub2}
}

// clustersEqual_internal reports whether two clusters (represented as sets)
// contain exactly the same members.
func clustersEqual_internal(a, b map[string]struct{}) bool {
	if len(a) != len(b) {
		return false
	}
	for k := range a {
		if _, ok := b[k]; !ok {
			return false
		}
	}
	return true
}

// GenerateValidNgrams generates all valid unigrams and bigrams that can be
// used as PREFIX patterns for subdomain matching according to RFC 1123 rules.
//
// PURPOSE:
// These ngrams are used to find subdomains that START with specific patterns,
// e.g. all subdomains starting with "a-", "ap", "1-".
//
// RFC 1123 PREFIX RULES:
//   - Valid characters: a-z, A-Z, 0-9, hyphen (-)
//   - MUST start with: letter or digit
//   - Second character can be: letter, digit, OR hyphen
//
// EXAMPLES:
//
//	Valid prefixes:   a, z, 0, 9, ab, a1, 1a, a-, 0-
//	Invalid prefixes: -, -a, -0 (cannot start with hyphen)
//
// RETURNS:
//
//	unigrams: all valid single-character prefixes (a-z, A-Z, 0-9) — 62 total
//	bigrams:  all valid two-character prefixes following RFC rules — 62*63 total
func GenerateValidNgrams() (unigrams []string, bigrams []string) {
	// Valid characters for subdomain labels.
	letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	digits := "0123456789"
	hyphen := "-"

	// Any valid character vs. characters allowed in first position
	// (a label must not start with a hyphen).
	allChars := letters + digits + hyphen
	startChars := letters + digits

	// Pre-size both slices: the counts are known exactly, which avoids
	// repeated append growth (62 unigrams, 62*63 bigrams).
	unigrams = make([]string, 0, len(startChars))
	bigrams = make([]string, 0, len(startChars)*len(allChars))

	// Unigrams: every character that may start a subdomain label.
	for _, c := range startChars {
		unigrams = append(unigrams, string(c))
	}

	// Bigrams: first char must be a letter/digit, second may be anything valid.
	for _, first := range startChars {
		for _, second := range allChars {
			bigrams = append(bigrams, string(first)+string(second))
		}
	}

	return unigrams, bigrams
}
assert.Contains(t, bigrams, prefix, "Valid prefix %s should be present", prefix) + } + + // Check invalid prefixes are NOT present + invalidPrefixes := []string{ + "-a", "-1", "--", // Cannot start with hyphen + } + + for _, prefix := range invalidPrefixes { + assert.NotContains(t, bigrams, prefix, "Invalid prefix %s should not be present", prefix) + } + }) + + t.Run("RFC_Compliance", func(t *testing.T) { + // Test that no ngram starts with hyphen (RFC 1123 requirement) + allNgrams := append(unigrams, bigrams...) + + for _, ngram := range allNgrams { + assert.False(t, strings.HasPrefix(ngram, "-"), + "Ngram should not start with hyphen: %s", ngram) + } + }) + + t.Run("Bigrams_With_Hyphen", func(t *testing.T) { + // Bigrams CAN end with hyphen as they are PREFIXES + // For example, "a-" is a valid prefix for "a-one.com" + hyphensFound := 0 + for _, b := range bigrams { + if strings.HasSuffix(b, "-") { + hyphensFound++ + } + } + // Should have 62 bigrams ending with hyphen (one for each start char) + assert.Equal(t, 62, hyphensFound, "Should have 62 bigrams ending with hyphen") + }) + + t.Run("Character_Validation", func(t *testing.T) { + validChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-" + allNgrams := append(unigrams, bigrams...) 
+ + for _, ngram := range allNgrams { + for _, c := range ngram { + assert.Contains(t, validChars, string(c), + "Ngram %s contains invalid character: %c", ngram, c) + } + } + }) +} + +func TestGenerateValidNgrams_Uniqueness(t *testing.T) { + unigrams, bigrams := GenerateValidNgrams() + + t.Run("Unigrams_Unique", func(t *testing.T) { + seen := make(map[string]bool) + for _, u := range unigrams { + assert.False(t, seen[u], "Duplicate unigram found: %s", u) + seen[u] = true + } + }) + + t.Run("Bigrams_Unique", func(t *testing.T) { + seen := make(map[string]bool) + for _, b := range bigrams { + assert.False(t, seen[b], "Duplicate bigram found: %s", b) + seen[b] = true + } + }) +} + +func TestGenerateValidNgrams_Examples(t *testing.T) { + unigrams, bigrams := GenerateValidNgrams() + + testCases := []struct { + name string + ngram string + isValid bool + isUnigram bool + }{ + // Valid unigrams + {"lowercase letter", "a", true, true}, + {"uppercase letter", "Z", true, true}, + {"digit", "5", true, true}, + + // Invalid unigrams + {"hyphen alone", "-", false, true}, + + // Valid bigrams (prefixes) + {"two letters", "ab", true, false}, + {"letter then digit", "a1", true, false}, + {"digit then letter", "1a", true, false}, + {"two digits", "99", true, false}, + {"mixed case", "Aa", true, false}, + {"letter then hyphen", "a-", true, false}, // Valid prefix for a-one.com + {"digit then hyphen", "1-", true, false}, // Valid prefix for 1-api.com + + // Invalid bigrams + {"starts with hyphen", "-a", false, false}, + {"two hyphens", "--", false, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var found bool + if tc.isUnigram { + for _, u := range unigrams { + if u == tc.ngram { + found = true + break + } + } + } else { + for _, b := range bigrams { + if b == tc.ngram { + found = true + break + } + } + } + + if tc.isValid { + assert.True(t, found, "Expected valid ngram %s to be present", tc.ngram) + } else { + assert.False(t, found, "Expected 
invalid ngram %s to be absent", tc.ngram) + } + }) + } +} + +func BenchmarkGenerateValidNgrams(b *testing.B) { + for i := 0; i < b.N; i++ { + GenerateValidNgrams() + } +} diff --git a/mutator.go b/mutator.go index 16ddd94f..e45500c7 100644 --- a/mutator.go +++ b/mutator.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/projectdiscovery/alterx/mining" "github.com/projectdiscovery/fasttemplate" "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/utils/dedupe" @@ -40,6 +41,17 @@ type Options struct { Enrich bool // MaxSize limits output data size MaxSize int + // Mode specifies which patterns to use: "default", "discover", or "both" + // - "default": use user-specified or default patterns only (default if not specified) + // - "discover": use mined patterns only (no defaults) + // - "both": combine mined patterns with defaults for maximum coverage + Mode string + // Mining/Discovery options (used when Mode="discover" or Mode="both") + MinLDist int // minimum levenshtein distance for clustering + MaxLDist int // maximum levenshtein distance for clustering + PatternThreshold int // threshold for filtering low-quality patterns + PatternQualityRatio int // pattern quality ratio threshold + NgramsLimit int // limit number of n-grams to process (0 = all) } // Mutator @@ -52,32 +64,162 @@ type Mutator struct { maxkeyLenInBytes int } +// createMiningOptions creates mining options with defaults applied +func createMiningOptions(opts *Options) *mining.Options { + miningOpts := &mining.Options{ + MinLDist: opts.MinLDist, + MaxLDist: opts.MaxLDist, + PatternThreshold: float64(opts.PatternThreshold), + PatternQualityRatio: float64(opts.PatternQualityRatio), + NgramsLimit: opts.NgramsLimit, + } + + // Apply defaults if not set (matching Python regulator defaults) + if miningOpts.MinLDist == 0 { + miningOpts.MinLDist = 2 + } + if miningOpts.MaxLDist == 0 { + miningOpts.MaxLDist = 5 + } + if miningOpts.PatternThreshold == 0 { + miningOpts.PatternThreshold = 
500 // Python default + } + if miningOpts.PatternQualityRatio == 0 { + miningOpts.PatternQualityRatio = 25 // Python default: max_ratio = 25.0 + } + + return miningOpts +} + // New creates and returns new mutator instance from options func New(opts *Options) (*Mutator, error) { if len(opts.Domains) == 0 { return nil, fmt.Errorf("no input provided to calculate permutations") } - if len(opts.Payloads) == 0 { - opts.Payloads = map[string][]string{} - if len(DefaultConfig.Payloads) == 0 { - return nil, fmt.Errorf("something went wrong, `DefaultWordList` and input wordlist are empty") - } - opts.Payloads = DefaultConfig.Payloads + + // Determine mode - default to "default" if not explicitly set + mode := opts.Mode + if mode == "" { + mode = "default" // use default patterns if mode not specified } - if len(opts.Patterns) == 0 { - if len(DefaultConfig.Patterns) == 0 { - return nil, fmt.Errorf("something went wrong,`DefaultPatters` and input patterns are empty") - } - opts.Patterns = DefaultConfig.Patterns + + // Validate mode + validModes := map[string]bool{"default": true, "discover": true, "both": true} + if !validModes[mode] { + return nil, fmt.Errorf("invalid mode '%s': must be 'default', 'discover', or 'both'", mode) } - // purge duplicates if any - for k, v := range opts.Payloads { - dedupe := sliceutil.Dedupe(v) - if len(v) != len(dedupe) { - gologger.Warning().Msgf("%v duplicate payloads found in %v. 
purging them..", len(v)-len(dedupe), k) - opts.Payloads[k] = dedupe + + // Create appropriate pattern provider(s) based on mode + var patterns []string + var payloads map[string][]string + + switch mode { + case "discover": + // Discover mode: mine patterns only (no defaults) + if len(opts.Domains) < 10 { + gologger.Warning().Msgf("discover mode performance may be degraded with less than 10 domains") + } + + miningOpts := createMiningOptions(opts) + provider := NewMinedPatternProvider(opts.Domains, miningOpts) + + var err error + patterns, payloads, err = provider.GetPatterns() + if err != nil { + return nil, fmt.Errorf("failed to mine patterns: %w", err) + } + + case "default": + // Default mode: use user-specified or default patterns + defaultPatterns := opts.Patterns + defaultPayloads := opts.Payloads + + // Apply defaults if not provided + if len(defaultPayloads) == 0 { + if len(DefaultConfig.Payloads) == 0 { + return nil, fmt.Errorf("something went wrong, `DefaultWordList` and input wordlist are empty") + } + defaultPayloads = DefaultConfig.Payloads + } + if len(defaultPatterns) == 0 { + if len(DefaultConfig.Patterns) == 0 { + return nil, fmt.Errorf("something went wrong, `DefaultConfig` and input Pattern are empty") + } + defaultPatterns = DefaultConfig.Patterns } + + // Deduplicate payloads + for k, v := range defaultPayloads { + dedupe := sliceutil.Dedupe(v) + if len(v) != len(dedupe) { + gologger.Warning().Msgf("%v duplicate payloads found in %v. 
purging them..", len(v)-len(dedupe), k) + defaultPayloads[k] = dedupe + } + } + + patterns = defaultPatterns + payloads = defaultPayloads + + case "both": + // Both mode: combine mined patterns with defaults + if len(opts.Domains) < 10 { + gologger.Warning().Msgf("discover mode performance may be degraded with less than 10 domains") + } + + // Get mined patterns + miningOpts := createMiningOptions(opts) + minedProvider := NewMinedPatternProvider(opts.Domains, miningOpts) + minedPatterns, minedPayloads, err := minedProvider.GetPatterns() + if err != nil { + return nil, fmt.Errorf("failed to mine patterns: %w", err) + } + + // Get default patterns + defaultPatterns := opts.Patterns + defaultPayloads := opts.Payloads + if len(defaultPayloads) == 0 { + defaultPayloads = DefaultConfig.Payloads + } + if len(defaultPatterns) == 0 { + defaultPatterns = DefaultConfig.Patterns + } + + // Combine patterns (mined + default, deduplicated) + patternSet := make(map[string]struct{}) + for _, p := range minedPatterns { + patternSet[p] = struct{}{} + } + for _, p := range defaultPatterns { + patternSet[p] = struct{}{} + } + patterns = make([]string, 0, len(patternSet)) + for p := range patternSet { + patterns = append(patterns, p) + } + + // Combine payloads (merge maps) + payloads = make(map[string][]string) + for k, v := range minedPayloads { + payloads[k] = v + } + for k, v := range defaultPayloads { + if existing, ok := payloads[k]; ok { + // Merge and deduplicate + combined := append(existing, v...) 
+ payloads[k] = sliceutil.Dedupe(combined) + } else { + payloads[k] = v + } + } + + gologger.Info().Msgf("Combined mode: %d mined + %d default = %d total patterns", + len(minedPatterns), len(defaultPatterns), len(patterns)) } + + // Update options with final patterns and payloads + opts.Patterns = patterns + opts.Payloads = payloads + m := &Mutator{ Options: opts, } @@ -90,6 +232,7 @@ func New(opts *Options) (*Mutator, error) { if opts.Enrich { m.enrichPayloads() } + return m, nil } diff --git a/pattern_provider.go b/pattern_provider.go new file mode 100644 index 00000000..3d44dd4c --- /dev/null +++ b/pattern_provider.go @@ -0,0 +1,116 @@ +package alterx + +import ( + "fmt" + "strings" + + "github.com/projectdiscovery/alterx/mining" + "github.com/projectdiscovery/gologger" +) + +// PatternProvider defines the interface for pattern generation strategies. +// Implementations provide patterns and payloads that can be used by the Mutator +// to generate domain permutations. +type PatternProvider interface { + // GetPatterns returns the patterns and their associated payloads. + // Returns: + // - patterns: slice of pattern strings in DSL format (e.g., "api-{{p0}}.{{root}}") + // - payloads: map of payload variables to their values (e.g., {"p0": ["prod", "dev"]}) + // - error: any error encountered during pattern generation + GetPatterns() (patterns []string, payloads map[string][]string, err error) +} + +// ManualPatternProvider provides user-specified patterns and payloads. +// This is the default mode where users explicitly provide patterns and wordlists. +type ManualPatternProvider struct { + patterns []string + payloads map[string][]string +} + +// NewManualPatternProvider creates a new manual pattern provider. 
+func NewManualPatternProvider(patterns []string, payloads map[string][]string) *ManualPatternProvider { + return &ManualPatternProvider{ + patterns: patterns, + payloads: payloads, + } +} + +// GetPatterns returns the manually specified patterns and payloads. +func (m *ManualPatternProvider) GetPatterns() ([]string, map[string][]string, error) { + if len(m.patterns) == 0 { + return nil, nil, fmt.Errorf("no patterns provided") + } + return m.patterns, m.payloads, nil +} + +// MinedPatternProvider discovers patterns from input domains using pattern mining algorithms. +// This mode automatically generates patterns by analyzing the structure of provided domains. +type MinedPatternProvider struct { + domains []string + miningOptions *mining.Options +} + +// NewMinedPatternProvider creates a new mined pattern provider. +func NewMinedPatternProvider(domains []string, opts *mining.Options) *MinedPatternProvider { + return &MinedPatternProvider{ + domains: domains, + miningOptions: opts, + } +} + +// GetPatterns mines patterns from the input domains and returns them in mutator format. 
+func (m *MinedPatternProvider) GetPatterns() ([]string, map[string][]string, error) { + gologger.Info().Msgf("Mining patterns from %d domains...", len(m.domains)) + + // Create pattern miner + pm, err := mining.NewPatternMiner(m.domains, m.miningOptions) + if err != nil { + return nil, nil, fmt.Errorf("failed to create pattern miner: %w", err) + } + + // Execute pattern mining (runs both Levenshtein and Hierarchical clustering) + if err := pm.Execute(); err != nil { + return nil, nil, fmt.Errorf("pattern mining failed: %w", err) + } + + // Get mined DSL patterns + dslPatterns := pm.GetResults() + if len(dslPatterns) == 0 { + return nil, nil, fmt.Errorf("no patterns discovered from input domains") + } + + gologger.Info().Msgf("Discovered %d patterns", len(dslPatterns)) + + // Convert DSLPatterns to mutator format with UNIQUE payload keys per pattern + // This prevents cross-contamination when multiple patterns use same key names (e.g., "p0") + // + // BEFORE (BUGGY): Pattern1="api{{p0}}", Pattern2="web{{p0}}" + // Merged payloads = {"p0": ["-prod", "-staging", "-dev", "-test"]} + // Result: BOTH patterns use ALL 4 values → 2x explosion + // + // AFTER (FIXED): Pattern1="api{{p0_0}}", Pattern2="web{{p0_1}}" + // Payloads = {"p0_0": ["-prod", "-staging"], "p0_1": ["-dev", "-test"]} + // Result: Each pattern uses only its own values → correct output + patterns := make([]string, 0, len(dslPatterns)) + allPayloads := make(map[string][]string) + + for patternIdx, dsl := range dslPatterns { + // Create unique payload keys by appending pattern index + // Original: "{{p0}}{{p1}}" → Unique: "{{p0_0}}{{p1_0}}" + uniquePattern := dsl.Pattern + for key := range dsl.Payloads { + uniqueKey := fmt.Sprintf("%s_%d", key, patternIdx) + // Replace old key with unique key in pattern + uniquePattern = strings.ReplaceAll(uniquePattern, "{{"+key+"}}", "{{"+uniqueKey+"}}") + // Store payloads with unique key + allPayloads[uniqueKey] = append([]string{}, dsl.Payloads[key]...) 
+ } + + // Add pattern with unique keys (append .{{root}} to match mutator expectations) + patterns = append(patterns, uniquePattern+".{{root}}") + } + + gologger.Info().Msgf("Generated %d unique payload keys", len(allPayloads)) + + return patterns, allPayloads, nil +} diff --git a/pattern_provider_test.go b/pattern_provider_test.go new file mode 100644 index 00000000..e1ee5c74 --- /dev/null +++ b/pattern_provider_test.go @@ -0,0 +1,159 @@ +package alterx + +import ( + "testing" + + "github.com/projectdiscovery/alterx/mining" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestManualPatternProvider(t *testing.T) { + patterns := []string{"{{word}}.{{root}}", "{{word}}-{{number}}.{{root}}"} + payloads := map[string][]string{ + "word": {"api", "dev", "staging"}, + "number": {"1", "2", "3"}, + } + + provider := NewManualPatternProvider(patterns, payloads) + require.NotNil(t, provider) + + gotPatterns, gotPayloads, err := provider.GetPatterns() + require.NoError(t, err) + + assert.Equal(t, patterns, gotPatterns) + assert.Equal(t, payloads, gotPayloads) +} + +func TestManualPatternProvider_EmptyPatterns(t *testing.T) { + provider := NewManualPatternProvider([]string{}, map[string][]string{}) + + _, _, err := provider.GetPatterns() + assert.Error(t, err) + assert.Contains(t, err.Error(), "no patterns provided") +} + +func TestMinedPatternProvider(t *testing.T) { + // Use simple test domains + domains := []string{ + "api-prod.example.com", + "api-staging.example.com", + "web-prod.example.com", + "web-staging.example.com", + } + + miningOpts := &mining.Options{ + MinLDist: 2, + MaxLDist: 5, + PatternThreshold: 1000, + PatternQualityRatio: 100, + } + + provider := NewMinedPatternProvider(domains, miningOpts) + require.NotNil(t, provider) + + patterns, payloads, err := provider.GetPatterns() + require.NoError(t, err) + + // Should generate at least some patterns + assert.Greater(t, len(patterns), 0, "Should generate at least one 
pattern") + + // Patterns should end with .{{root}} + for _, pattern := range patterns { + assert.Contains(t, pattern, "{{", "Pattern should contain DSL syntax") + } + + // Should have some payloads + assert.Greater(t, len(payloads), 0, "Should generate at least one payload key") + + t.Logf("Generated %d patterns", len(patterns)) + t.Logf("Generated %d payload keys", len(payloads)) + + // Log a few patterns for inspection + for i, pattern := range patterns { + if i >= 5 { + t.Logf("... and %d more patterns", len(patterns)-5) + break + } + t.Logf(" Pattern %d: %s", i+1, pattern) + } +} + +func TestMinedPatternProvider_InsufficientDomains(t *testing.T) { + // Test with very few domains + domains := []string{ + "api.example.com", + } + + miningOpts := &mining.Options{ + MinLDist: 2, + MaxLDist: 5, + } + + provider := NewMinedPatternProvider(domains, miningOpts) + patterns, payloads, err := provider.GetPatterns() + + // Should handle gracefully (might return error or simple patterns) + if err != nil { + t.Logf("Expected behavior: %v", err) + } else { + t.Logf("Generated %d patterns from single domain", len(patterns)) + t.Logf("Payloads: %v", payloads) + } +} + +func TestMutatorIntegration_ManualMode(t *testing.T) { + opts := &Options{ + Domains: []string{"example.com"}, + Patterns: []string{"{{word}}.{{root}}"}, + Payloads: map[string][]string{ + "word": {"api", "dev"}, + }, + Mode: "default", + } + + mutator, err := New(opts) + require.NoError(t, err) + require.NotNil(t, mutator) + + assert.Equal(t, 1, len(mutator.Options.Patterns)) + assert.Equal(t, 2, len(mutator.Options.Payloads["word"])) +} + +func TestMutatorIntegration_DiscoverMode(t *testing.T) { + opts := &Options{ + Domains: []string{ + "api-prod.example.com", + "api-staging.example.com", + "web-prod.example.com", + "web-staging.example.com", + "db-primary.example.com", + "db-secondary.example.com", + "cache-1.example.com", + "cache-2.example.com", + "app-v1.example.com", + "app-v2.example.com", + }, + Mode: 
"discover", + } + + mutator, err := New(opts) + require.NoError(t, err) + require.NotNil(t, mutator) + + // Should have discovered patterns and payloads + assert.Greater(t, len(mutator.Options.Patterns), 0, "Should discover patterns") + assert.Greater(t, len(mutator.Options.Payloads), 0, "Should discover payloads") + + t.Logf("Discovered %d patterns", len(mutator.Options.Patterns)) + t.Logf("Discovered %d payload keys", len(mutator.Options.Payloads)) + + // Show some patterns + for i, pattern := range mutator.Options.Patterns { + if i >= 5 { + t.Logf("... and %d more patterns", len(mutator.Options.Patterns)-5) + break + } + t.Logf(" Pattern %d: %s", i+1, pattern) + } +} diff --git a/util.go b/util.go index ece16517..20984b16 100644 --- a/util.go +++ b/util.go @@ -7,7 +7,7 @@ import ( "unsafe" ) -var varRegex = regexp.MustCompile(`\{\{([a-zA-Z0-9]+)\}\}`) +var varRegex = regexp.MustCompile(`\{\{([a-zA-Z0-9_]+)\}\}`) // returns no of variables present in statement func getVarCount(data string) int {