Pretrained models are available on Hugging Face if you would rather not train them from scratch: https://huggingface.co/markoGeorgiaTech
Note that the notebook does not render in the GitHub UI because of its size and format. Download it or open it in Jupyter, VS Code, or Colab to see its contents.
NOTE: We cannot include the SENS-HEAD data due to privacy issues with its owners. It has been redacted from the repo.
Please email mgjurevski3@gatech.edu or aszanti3@gatech.edu with any questions, comments, or concerns.
This project trains and evaluates text classification models using both a TF-IDF baseline and a DistilBERT model. The workflow is organized into scripts for data splitting, baseline training, transformer fine-tuning, and error analysis.
Project Structure: data/ contains train.csv, val.csv, and test.csv. outputs/baseline/ stores the TF-IDF and logistic regression models. outputs/distilbert/ stores the fine-tuned transformer model files.
Split the dataset: Use a CSV that has at least two columns, text and y.
Example command: python split_data.py --infile my_data.csv --outdir data
This creates train.csv, val.csv, and test.csv using stratified sampling.
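To illustrate what a stratified three-way split looks like, here is a minimal sketch using pandas and scikit-learn. It is not the code in split_data.py: the 80/10/10 ratios, the random seed, and the helper name stratified_split are assumptions made for the example.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def stratified_split(df, label_col="y", val_frac=0.1, test_frac=0.1, seed=42):
    """Split df into train/val/test, keeping the label distribution in each part.

    Hypothetical helper; the fractions and seed are illustrative assumptions.
    """
    # First carve off the combined val+test portion, stratified on the label.
    holdout_frac = val_frac + test_frac
    train, holdout = train_test_split(
        df, test_size=holdout_frac, stratify=df[label_col], random_state=seed
    )
    # Then divide the holdout into val and test, again stratified on the label.
    val, test = train_test_split(
        holdout,
        test_size=test_frac / holdout_frac,
        stratify=holdout[label_col],
        random_state=seed,
    )
    return train, val, test


# Toy data standing in for my_data.csv (columns: text, y).
df = pd.DataFrame({
    "text": [f"example {i}" for i in range(100)],
    "y": [0] * 70 + [1] * 30,
})
train, val, test = stratified_split(df)
print(len(train), len(val), len(test))
```

Stratifying both steps on y keeps the class ratio of the full dataset roughly the same in all three files, which matters most when one class is rare.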
Train the baseline model: This trains a TF-IDF plus logistic regression classifier.
Example command: python train_baseline.py --datadir data --outdir outputs/baseline
The script prints accuracy and F1 scores for both validation and test sets. It saves tfidf.joblib and lr.joblib.
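The baseline can be pictured with the following minimal sketch. It uses scikit-learn defaults on toy data, so the vectorizer settings, max_iter value, and example sentences are assumptions rather than what train_baseline.py actually uses. Only the output file names come from the description above.

```python
import os
import tempfile

import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score

# Toy data standing in for data/train.csv and data/val.csv.
train_texts = [
    "great movie", "loved it", "wonderful acting",
    "terrible plot", "hated it", "awful pacing",
]
train_y = [1, 1, 1, 0, 0, 0]
val_texts = ["loved the acting", "terrible pacing"]
val_y = [1, 0]

# Fit TF-IDF on the training text only, then reuse the fitted vocabulary.
tfidf = TfidfVectorizer()  # default settings (assumption)
X_train = tfidf.fit_transform(train_texts)
X_val = tfidf.transform(val_texts)

lr = LogisticRegression(max_iter=1000)  # max_iter is an assumption
lr.fit(X_train, train_y)

val_pred = lr.predict(X_val)
val_acc = accuracy_score(val_y, val_pred)
val_f1 = f1_score(val_y, val_pred)
print(f"val accuracy={val_acc:.3f} f1={val_f1:.3f}")

# Persist both artifacts under the file names the script uses.
outdir = tempfile.mkdtemp()  # stand-in for outputs/baseline
joblib.dump(tfidf, os.path.join(outdir, "tfidf.joblib"))
joblib.dump(lr, os.path.join(outdir, "lr.joblib"))
```

Saving the vectorizer alongside the classifier matters because new text must be transformed with the same fitted vocabulary before it can be passed to the model.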
Fine-tune DistilBERT:
Example command: python train_transformer.py --datadir data --model distilbert-base-uncased --outdir outputs/distilbert --epochs 3 --bsz 16 --maxlen 192 --lr 2e-5
The model and tokenizer are saved in the output directory.
Analyze model errors: This identifies the most confident false positives and false negatives produced by the transformer model.
Example command: python analyze_errors.py --csv data/test.csv --modeldir outputs/distilbert --k 50
Outputs: outputs/top_fp.csv and outputs/top_fn.csv
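The idea of "most confident" errors can be sketched as follows. This is an illustration under stated assumptions, not the logic of analyze_errors.py: it assumes binary labels (0 or 1), a hypothetical p_pos column holding the model's predicted probability of class 1, and a 0.5 decision threshold. The helper name top_confident_errors is also hypothetical.

```python
import pandas as pd


def top_confident_errors(df, k=50, threshold=0.5):
    """Return (false positives, false negatives), each sorted by model confidence.

    Expects columns: text, y (true label, 0 or 1), and p_pos (predicted
    probability of class 1). The p_pos column name is an assumption.
    """
    pred = (df["p_pos"] >= threshold).astype(int)
    # False positives: predicted 1, actually 0. Most confident = highest p_pos.
    fp = df[(pred == 1) & (df["y"] == 0)].sort_values("p_pos", ascending=False).head(k)
    # False negatives: predicted 0, actually 1. Most confident = lowest p_pos.
    fn = df[(pred == 0) & (df["y"] == 1)].sort_values("p_pos", ascending=True).head(k)
    return fp, fn


# Toy predictions standing in for the transformer's output on data/test.csv.
df = pd.DataFrame({
    "text": ["a", "b", "c", "d", "e", "f"],
    "y": [0, 0, 1, 1, 0, 1],
    "p_pos": [0.95, 0.60, 0.05, 0.40, 0.10, 0.90],
})
top_fp, top_fn = top_confident_errors(df, k=2)
print(top_fp)
print(top_fn)
```

Each returned frame could then be written out with to_csv. Reading the highest-confidence mistakes first is often the quickest way to spot mislabeled rows or systematic blind spots.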
Requirements: Python 3.8 or newer, PyTorch, transformers, datasets, scikit-learn, pandas, joblib
Dependencies can be installed with: pip install torch transformers datasets scikit-learn pandas joblib
GPU usage is automatic when available.
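How the scripts select a device is not described here. A common PyTorch pattern for automatic GPU use, shown as a sketch and not as the scripts' actual code, is:

```python
import torch

# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tensors (and models, via model.to(device)) must live on the same device.
x = torch.ones(2, 3).to(device)
print(f"running on {device}, tensor sum = {x.sum().item()}")
```

On a machine without a CUDA-capable GPU the same code runs unchanged on the CPU, only more slowly.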