This repository contains an advanced, multi-stage pipeline for multi-label stuttering event detection using metric learning and retrieval-augmented neural networks.
The pipeline combines two expert systems:
- Expert A: Pure audio processing via a Conformer encoder
- Expert B: Semantic retrieval using metric-learned embeddings and k-nearest neighbors
A learnable gating mechanism fuses predictions from both experts, resulting in a system that achieves strong performance on the SEP-28k dataset with interpretable decision-making through retrieval-based explanations.
Stuttering is a complex speech disorder in which disfluencies (e.g., blocks, prolongations, sound and word repetitions, interjections) frequently co-occur, which makes detection a natural multi-label classification problem. This project uses a strict speaker-independent protocol to prevent data leakage between splits and to better reflect use on unseen speakers.
Rather than relying on a single model, we employ an ensemble architecture that combines two complementary experts:
- Expert A (Audio Conformer): Processes raw Wav2Vec2 features through a Conformer encoder with convolutional and attention modules for direct acoustic modeling.
- Expert B (Semantic Retrieval): Learns a metric space via Multi-Label Supervised Contrastive Loss, enabling semantic k-NN retrieval. This expert predicts based on the labels and similarity scores of the k-nearest neighbors in embedding space.
A learnable gating mechanism dynamically combines both expert predictions, automatically learning when to trust audio-only signals vs. retrieval-based evidence.
- Multi-Label Supervised Contrastive Learning: A custom loss function that trains embeddings based on Jaccard similarity of label vectors, creating a semantic space where clips with overlapping stuttering types cluster together.
- Retrieval-Augmented Architecture: Uses FAISS-indexed embeddings for efficient k-NN search. The query clip's features are combined with neighbor labels and similarity scores for joint reasoning.
- Ensemble Gating Strategy: Learnable gates automatically balance audio-only and retrieval-based predictions, enabling adaptive decision-making for different utterance types.
- Built-in Interpretability: Retrieval decisions are naturally interpretable: predictions can be explained by showing which training samples the model "retrieved" for comparison.
The pipeline is trained on the SEP-28k dataset, a large, multi-label stuttering event annotation corpus. After filtering out "NoStutter"-only clips, the task involves classifying the 29,783 clips that contain at least one disfluency.
- Prolongation: Sound prolonged within a word
- Block: Complete stop or blockage during speech initiation
- SoundRep: Repetition of a single sound
- WordRep: Repetition of an entire word
- Interjection: Insertion of filler words or sounds
| Class | Clips with Label |
|---|---|
| Block | 8,081 |
| Interjection | 5,934 |
| Prolongation | 5,629 |
| SoundRep | 3,486 |
| WordRep | 2,759 |
Over 56% of all disfluent clips contain two or more co-occurring stuttering events:
| Labels per Clip | Count | Percentage |
|---|---|---|
| 1 | 12,975 | 43.57% |
| 2 | 10,964 | 36.81% |
| 3 | 4,746 | 15.94% |
| 4 | 1,014 | 3.40% |
| 5 | 84 | 0.28% |
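The co-occurrence figure can be checked directly from the table. The short snippet below only re-does that arithmetic; the counts are copied from the table and the variable names are illustrative.

```python
# Labels-per-clip counts copied from the table above.
labels_per_clip = {1: 12_975, 2: 10_964, 3: 4_746, 4: 1_014, 5: 84}

total_clips = sum(labels_per_clip.values())
multi_label_clips = sum(n for k, n in labels_per_clip.items() if k >= 2)
multi_label_share = 100.0 * multi_label_clips / total_clips

print(f"total clips: {total_clips}")
print(f"clips with 2+ labels: {multi_label_clips} ({multi_label_share:.2f}%)")
```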
The system consists of four main stages, executed on TPU hardware for scalability:
We present two complementary architectures:
Figure 1: Mid-Fusion RAC (Retrieval Augmented Conformer) - Cross-attention-based fusion within the model:

Figure 2: Late-Fusion via Gated Experts - Post-prediction gating mechanism:

- Speaker-Independent Split: Train/validation/test sets are split such that no speaker appears in multiple sets, ensuring speaker-independent generalization.
- Instance-Balanced Augmentation: Underrepresented label combinations are oversampled using `audiomentations` to create a balanced training distribution (~10.5k samples per class).
- Wav2Vec2 Feature Extraction: All audio clips are processed through a frozen `facebook/wav2vec2-base-960h` model, producing compressed feature arrays of shape `(150 frames, 1024 dims)`.
- Output: `.npz` files containing feature matrices and one-hot encoded label vectors.
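A speaker-independent split assigns whole speakers, not individual clips, to each set. The following is a minimal sketch of that idea, not the notebook's code: the function name `split_by_speaker`, the 80/10/10 fractions, and the toy speaker IDs are all assumptions made for illustration.

```python
import random
from collections import defaultdict

def split_by_speaker(speaker_ids, fractions=(0.8, 0.1, 0.1), seed=0):
    """Assign whole speakers to train/val/test; return clip indices per split."""
    clips_of = defaultdict(list)
    for clip_idx, speaker in enumerate(speaker_ids):
        clips_of[speaker].append(clip_idx)

    speakers = sorted(clips_of)            # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(speakers)

    n_train = int(fractions[0] * len(speakers))
    n_val = int(fractions[1] * len(speakers))
    groups = {
        "train": speakers[:n_train],
        "val": speakers[n_train:n_train + n_val],
        "test": speakers[n_train + n_val:],   # remainder goes to test
    }
    return {name: sorted(i for s in group for i in clips_of[s])
            for name, group in groups.items()}

# Toy example: 10 hypothetical speakers with 3 clips each.
speaker_ids = [f"spk{c // 3}" for c in range(30)]
splits = split_by_speaker(speaker_ids)
speakers_in = {name: {speaker_ids[i] for i in idx} for name, idx in splits.items()}
print({name: len(idx) for name, idx in splits.items()})
```

Because the shuffle operates on speakers, every clip from a given speaker lands in exactly one split, which is the property the protocol relies on.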
This stage trains a metric-learning embedder that groups clips by label similarity:
- Architecture: BiGRU encoder (256 hidden units) → Dense projection (1024 dims) → Attention pooling → L2 normalization
- Loss Function: Multi-Label Supervised Contrastive Loss
- Computes cosine similarity matrix between embeddings
- Uses Jaccard similarity of label vectors to weight the loss
- Only samples with overlapping labels contribute to the contrastive signal
- Numerical stability via log-sum-exp tricks
- Training: 50 epochs with cosine decay learning rate, early stopping on Recall@K metric
- Output: Trained embedder weights and embeddings for all training samples
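The loss described above can be sketched in NumPy as follows. This is one possible reading of the bullet points, not the notebook's `MultiLabelSupConLoss`: the temperature of 0.1, the per-anchor normalization by the sum of Jaccard weights, and the function name are assumptions. It expects L2-normalized embeddings and a batch of at least two clips.

```python
import numpy as np

def multilabel_supcon_loss(embeddings, labels, temperature=0.1):
    """Jaccard-weighted supervised contrastive loss (illustrative sketch).

    embeddings: (B, D) float array with unit-norm rows.
    labels:     (B, C) multi-hot array.
    """
    labels = labels.astype(np.float64)
    batch = embeddings.shape[0]

    # Cosine similarity (rows are unit-norm), scaled by the temperature.
    logits = embeddings @ embeddings.T / temperature

    # Jaccard similarity of label vectors: |A and B| / |A or B|.
    inter = labels @ labels.T
    counts = labels.sum(axis=1)
    union = counts[:, None] + counts[None, :] - inter
    jaccard = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)

    # Self-pairs are excluded from the positives and from the softmax denominator.
    not_self = ~np.eye(batch, dtype=bool)
    weights = jaccard * not_self           # zero where labels do not overlap

    # Log-softmax over all other samples via the log-sum-exp trick.
    masked = np.where(not_self, logits, -np.inf)
    row_max = masked.max(axis=1, keepdims=True)
    log_denom = row_max + np.log(np.exp(masked - row_max).sum(axis=1, keepdims=True))
    log_prob = logits - log_denom

    # Weighted mean log-probability of positives, for anchors that have any.
    weight_sum = weights.sum(axis=1)
    has_pos = weight_sum > 0
    per_anchor = -(weights * log_prob).sum(axis=1)[has_pos] / weight_sum[has_pos]
    return per_anchor.mean() if has_pos.any() else 0.0

labels = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
clustered = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
scrambled = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
loss_clustered = multilabel_supcon_loss(clustered, labels)
loss_scrambled = multilabel_supcon_loss(scrambled, labels)
print(loss_clustered, loss_scrambled)
```

In the toy batch, embeddings that agree with the labels give a loss near zero, while embeddings that pair each clip with a differently labeled clip are penalized heavily.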
- FAISS Indexing: Embeddings from Stage 2 are added to a FAISS IndexFlatIP (inner product search with normalized vectors = cosine similarity)
- k-NN Search (k=5): For each query clip (train/val/test), retrieve the k most similar training samples
- Neighbor Features: Gather neighbor embeddings, ground-truth labels, and similarity scores into structured tensors
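The retrieval step can be illustrated without FAISS. The sketch below uses a brute-force NumPy search, which returns the same exact inner-product ranking that a flat inner-product index computes; the random data, array names, and helper functions are placeholders, not the notebook's variables.

```python
import numpy as np

def l2_normalize(x):
    """Scale each row to unit length so inner product equals cosine similarity."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def knn_inner_product(queries, index_vectors, k=5):
    """Exact top-k inner-product search; returns (scores, ids), each of shape (Q, k)."""
    scores = queries @ index_vectors.T
    ids = np.argsort(-scores, axis=1)[:, :k]
    return np.take_along_axis(scores, ids, axis=1), ids

# With FAISS the equivalent calls would be:
#   index = faiss.IndexFlatIP(dim); index.add(train_emb); index.search(query_emb, k)
rng = np.random.default_rng(0)
train_emb = l2_normalize(rng.normal(size=(100, 1024)).astype(np.float32))
train_labels = rng.integers(0, 2, size=(100, 5))
query_emb = train_emb[:3]            # training clips used as queries

sims, ids = knn_inner_product(query_emb, train_emb, k=5)
neighbor_emb = train_emb[ids]        # (3, 5, 1024) neighbor embeddings
neighbor_labels = train_labels[ids]  # (3, 5, 5) neighbor ground-truth labels
print(ids[:, 0], sims[:, 0])
```

Note that when a training clip is used as the query, its top hit is the clip itself with similarity 1. A common remedy is to search for k+1 neighbors and drop the self-match; whether the notebook does this is not stated here.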
- Input: Wav2Vec2 features (150 × 1024)
- Architecture:
- 2 Conformer blocks with spatial dropouts and Gaussian noise for regularization
- Each block contains: Conv module (depthwise convolutions) + Multi-head attention + Feed-forward network
- Global average pooling
- Dense layers: 256 dims → output (5 classes)
- Loss: Binary cross-entropy with label smoothing
- Training: 30 epochs with AdamW optimizer
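To make the effect of label smoothing concrete, here is a small NumPy sketch of binary cross-entropy with smoothed targets. The smoothing value of 0.1 is an assumption (the README does not state it); the target formula follows the convention used by Keras's binary cross-entropy, which pulls targets toward 0.5.

```python
import numpy as np

def smoothed_bce(y_true, y_prob, label_smoothing=0.1, eps=1e-7):
    """Mean binary cross-entropy with targets pulled toward 0.5."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_prob = np.clip(np.asarray(y_prob, dtype=np.float64), eps, 1.0 - eps)
    # 1 -> 1 - s/2 and 0 -> s/2 for smoothing factor s.
    y_smooth = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing
    losses = -(y_smooth * np.log(y_prob) + (1.0 - y_smooth) * np.log(1.0 - y_prob))
    return losses.mean()

y_true = np.array([[1, 0, 0, 1, 0]])                      # one clip, five classes
confident = np.array([[0.999, 0.001, 0.001, 0.999, 0.001]])
moderate = np.array([[0.95, 0.05, 0.05, 0.95, 0.05]])

print("smoothed:", smoothed_bce(y_true, confident), smoothed_bce(y_true, moderate))
print("plain:   ", smoothed_bce(y_true, confident, 0.0), smoothed_bce(y_true, moderate, 0.0))
```

With smoothing, the extremely confident prediction scores worse than the moderate one, which is how the loss discourages over-confident outputs on noisy annotations.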
- Input: k neighbor embeddings (5 × 1024)
- Architecture:
- Dense projection: 1024 → 256 dims with swish activation
- Global average pooling across neighbors
- Dense layers: 256 → 128 dims
- Output: 5 classes with sigmoid activation
- Loss: Binary cross-entropy
- Training: 30 epochs with AdamW optimizer (higher learning rate for faster convergence)
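The tensor shapes in Expert B can be traced with a short NumPy forward pass. This is a shape walkthrough with random weights, not the trained Keras model; the swish activation on the 256 → 128 layer and all variable names are assumptions.

```python
import numpy as np

def swish(x):
    return x / (1.0 + np.exp(-x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_dense(rng, in_dim, out_dim):
    """Random weights standing in for trained parameters."""
    return rng.normal(scale=in_dim ** -0.5, size=(in_dim, out_dim)), np.zeros(out_dim)

rng = np.random.default_rng(0)
w_proj, b_proj = init_dense(rng, 1024, 256)
w_hid, b_hid = init_dense(rng, 256, 128)
w_out, b_out = init_dense(rng, 128, 5)

def expert_b_forward(neighbor_emb):
    """neighbor_emb: (B, k, 1024) embeddings of the k retrieved neighbors."""
    h = swish(neighbor_emb @ w_proj + b_proj)   # (B, k, 256) per-neighbor projection
    pooled = h.mean(axis=1)                     # (B, 256) average over neighbors
    features = swish(pooled @ w_hid + b_hid)    # (B, 128) activation is assumed
    probs = sigmoid(features @ w_out + b_out)   # (B, 5) per-class probabilities
    return features, probs

neighbor_emb = rng.normal(size=(4, 5, 1024))    # batch of 4 queries, k = 5
features, probs = expert_b_forward(neighbor_emb)
print(features.shape, probs.shape)
```

The 128-dimensional `features` tensor corresponds to the intermediate representation that the fusion stage takes from Expert B.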
- Extract intermediate features from both experts (256 dims from A, 128 dims from B)
- Trust Gate: Learnable sigmoid layer that outputs a [0, 1] weight vector (128 dims)
- Weight = 1 → trust retrieval evidence
- Weight = 0 → rely on audio-only signal
- Gated Retrieval: Element-wise multiply gate with retrieval features
- Final Classifier: Concatenate audio + gated retrieval features → Dense layers → output
- Training: 20 epochs on combined audio + neighbor vectors
All results are from the test set using speaker-independent evaluation.
The audio-only model (Expert A) sets the baseline:
- Binary Accuracy: ~68%
- ROC-AUC: ~0.82
- PR-AUC: ~0.76
The ensemble with gating mechanism provides improved predictions:
- Automatic threshold optimization per class via F1-score grid search
- Adaptive weighting between audio and retrieval evidence
- Per-class results show consistent improvements on imbalanced classes
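Per-class threshold tuning can be sketched as a grid search on validation predictions. The version below maximizes each class's own F1 score independently; the grid of 0.05 to 0.95, the function names, and the toy data are assumptions, and the notebook's search may optimize a different aggregate (for example micro F1).

```python
import numpy as np

def f1_binary(y_true, y_pred):
    """F1 for one class from 0/1 arrays; returns 0.0 when there are no positives at all."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else 0.0

def tune_thresholds(y_true, y_prob, grid=None):
    """Return one F1-maximizing threshold per class (first maximum on the grid)."""
    if grid is None:
        grid = np.linspace(0.05, 0.95, 19)
    thresholds = []
    for c in range(y_true.shape[1]):
        scores = [f1_binary(y_true[:, c], (y_prob[:, c] >= t).astype(int)) for t in grid]
        thresholds.append(float(grid[int(np.argmax(scores))]))
    return np.array(thresholds)

# Toy validation set: 5 clips, 2 classes, probabilities not centered on 0.5.
y_true = np.array([[1, 0], [1, 0], [0, 1], [0, 1], [0, 0]])
y_prob = np.array([[0.42, 0.52], [0.47, 0.33], [0.12, 0.81], [0.22, 0.77], [0.31, 0.61]])

thresholds = tune_thresholds(y_true, y_prob)
y_pred = (y_prob >= thresholds).astype(int)
print(thresholds)
```

Thresholds should be chosen on the validation split and then applied unchanged to the test split, otherwise the reported test scores are optimistic.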
The notebook includes t-SNE visualizations showing:
- Raw Features vs. Learned Embeddings: Comparison of baseline Wav2Vec2 pooling vs. metric-learned space
- Class Separation: Clear clustering of similar stuttering types in the learned embedding space
- Multi-Label Interactions: Top 12 label co-occurrence combinations visualized by location in embedding space
- Per-Class Decomposition: 5-subplot grid showing where each stuttering type concentrates in the embedding space
.
├── README.md # This file
├── stut-vecs/
│ ├── stut-vec11.ipynb # Main training & evaluation notebook (TPU-optimized)
│ ├── stut-vec7.ipynb # Earlier retrieval-augmented experiments
│ └── [other exploratory notebooks]
├── Dataset/
│ ├── SEP-28k_episodes.csv
│ ├── SEP-28k_labels.csv
│ ├── KSoF_release/ # Kassel State of Fluency auxiliary dataset
│ └── FluencyBank/
├── Notebooks/ # Organized by feature/architecture
│ ├── mfcc/ # MFCC-based baselines
│ ├── wav2vec/ # Wav2Vec2 experiments
│ └── Oth/ # Other exploration
└── [training outputs, weights, indices]
Main Implementation: All model training and evaluation code is contained in stut-vecs/stut-vec11.ipynb. This notebook:
- Initializes TPU strategy for distributed training
- Loads and preprocesses data
- Trains the metric-learning embedder with MultiLabelSupConLoss
- Builds FAISS index for k-NN retrieval
- Trains Expert A (Audio Conformer)
- Trains Expert B (Neural Retrieval)
- Trains fusion model with learnable gating
- Evaluates all components and generates visualizations
- tensorflow ≥ 2.18.0 (with TPU support for distributed training)
- tensorflow-tpu ≥ 2.18.0 (if using TPU)
- faiss-cpu or faiss-gpu (for exact inner-product nearest-neighbor search)
- transformers ≥ 4.0.0 (for Wav2Vec2 model access)
- numpy, scipy
- pandas, scikit-learn (for metrics and data manipulation)
- matplotlib, seaborn (for visualization)
- audiomentations (for data augmentation)
pip install tensorflow==2.18.0 tensorflow-tpu==2.18.0 \
    faiss-cpu transformers pandas scikit-learn matplotlib seaborn audiomentations

- Prepare Data: Ensure the SEP-28k dataset is available in the `Dataset/` directory
- Open Notebook: Launch `stut-vecs/stut-vec11.ipynb` in Jupyter or Kaggle
- TPU/GPU Setup: Configure the `strategy` variable for your hardware (TPU, GPU, or CPU)
- Run All Cells: Execute cells sequentially to reproduce the full pipeline
The notebook will output:
- Trained model weights (`model_expert_a.keras`, `expert_b.keras`, `model_fusion_final.keras`)
- FAISS index (`train_data.index`)
- Visualization plots (t-SNE, per-class decomposition, training history)
- Classification reports and F1-score analysis
# Pseudocode: weighted contrastive signal based on label overlap
similarity_matrix = cosine(embeddings)
label_overlap = jaccard_similarity(labels)
# Mask: only non-identical pairs with overlapping labels contribute
mask = (label_overlap > 0) & (identity_matrix == 0)
# Weighted log-prob loss
loss = -mean(mask * log_softmax(similarity_matrix / temperature))

The fusion model learns a trust gate that adaptively weights the retrieval expert.
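A minimal NumPy sketch of such a trust gate is given here; it is not the notebook's Keras code. It assumes the gate is computed from the concatenated audio (256-dim) and retrieval (128-dim) features, which the README does not specify, and it uses random placeholder weights.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_fusion(audio_feat, retrieval_feat, w_gate, b_gate):
    """Scale retrieval features by a learned [0, 1] gate, then concatenate with audio.

    audio_feat: (B, 256), retrieval_feat: (B, 128),
    w_gate: (384, 128), b_gate: (128,). The gate input is an assumption.
    """
    both = np.concatenate([audio_feat, retrieval_feat], axis=-1)   # (B, 384)
    gate = sigmoid(both @ w_gate + b_gate)                         # (B, 128)
    fused = np.concatenate([audio_feat, gate * retrieval_feat], axis=-1)
    return fused, gate                                             # fused: (B, 384)

rng = np.random.default_rng(0)
audio_feat = rng.normal(size=(2, 256))
retrieval_feat = rng.normal(size=(2, 128))
w_gate = rng.normal(scale=0.05, size=(384, 128))

# A large positive bias opens the gate (trust retrieval); a large negative one closes it.
fused_open, gate_open = gated_fusion(audio_feat, retrieval_feat, w_gate, np.full(128, 20.0))
fused_closed, gate_closed = gated_fusion(audio_feat, retrieval_feat, w_gate, np.full(128, -20.0))
print(gate_open.min(), gate_closed.max())
```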
This allows the model to ignore retrieval evidence when audio signals are clear and leverage neighbors when audio is ambiguous.
- TPU Acceleration: This code is optimized for TPU v5-e8. For GPU, adjust `strategy` and reduce batch sizes
- Memory Management: Large batch sizes (512+) enable stable contrastive learning; use `tf.data.AUTOTUNE`
- Threshold Tuning: Micro F1-score grid search finds optimal thresholds per class (typically 0.3-0.5)
- Early Stopping: Monitor validation ROC-AUC; default patience is 5-20 epochs depending on stage
This work builds upon:
- Wav2Vec2 (Baevski et al., 2020) for acoustic feature extraction
- Conformer architecture (Gulati et al., 2020) for sequence modeling
- Supervised Contrastive Learning (Khosla et al., 2020) adapted for multi-label settings
- FAISS (Johnson et al., 2019) for efficient similarity search
If you use this work, please cite:
@article{stutterfuse2026,
title={StutterFuse: Mitigating Modality Collapse in Stuttering Detection with Jaccard-Weighted Metric Learning and Gated Fusion},
author={Singh, Guransh and Fahad, Md. Shah and Amritanjali},
journal={IEEE Transactions on Audio, Speech, and Language Processing},
year={2026}
}

For questions, bug reports, or collaboration:
- Open an issue on GitHub
- Check out the main notebook: `stut-vecs/stut-vec11.ipynb`