This is the default final project for the Stanford CS 224N class. Please refer to the project handout on the course website for detailed instructions and an overview of the codebase.
This project comprises two parts. In the first part, you will implement some important components of the GPT-2 model to better understand its architecture. In the second part, you will use the token embeddings produced by your GPT-2 model on two downstream tasks: paraphrase detection and sonnet generation. You will implement extensions to improve your model's performance on these tasks.
In broad strokes, Part 1 of this project targets:
- `modules/attention.py`: Missing code blocks.
- `modules/gpt2_layer.py`: Missing code blocks.
- `models/gpt2.py`: Missing code blocks.
- `classifier.py`: Missing code blocks.
- `optimizer.py`: Missing code blocks.
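`modules/attention.py` is where you implement GPT-2's multi-head causal self-attention. As a reference for the math only, here is a minimal numpy sketch of that computation; the function name, single-batch layout, and shapes are illustrative, not the codebase's actual API:

```python
import numpy as np

def causal_self_attention(x, w_q, w_k, w_v, n_heads):
    """Single-batch multi-head causal self-attention (sketch).

    x: (T, d) token representations; w_q / w_k / w_v: (d, d)
    projection matrices. Illustrative names, not the repo's API.
    """
    T, d = x.shape
    hd = d // n_heads                                 # per-head dimension
    # Project, then split into heads: (n_heads, T, hd)
    q = (x @ w_q).reshape(T, n_heads, hd).transpose(1, 0, 2)
    k = (x @ w_k).reshape(T, n_heads, hd).transpose(1, 0, 2)
    v = (x @ w_v).reshape(T, n_heads, hd).transpose(1, 0, 2)
    # Scaled dot-product scores with a causal (upper-triangular) mask
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(hd)   # (n_heads, T, T)
    mask = np.triu(np.ones((T, T), dtype=bool), 1)    # future positions
    scores = np.where(mask, -1e9, scores)
    # Softmax over the key dimension
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn = e / e.sum(axis=-1, keepdims=True)
    out = attn @ v                                    # (n_heads, T, hd)
    return out.transpose(1, 0, 2).reshape(T, d)       # re-merge heads
```

The causal mask guarantees that position t attends only to positions ≤ t, which is what makes autoregressive generation possible.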
To test Part 1, you will run:
- `optimizer_test.py`: To test your implementation of `optimizer.py`.
- `sanity_check.py`: To test your implementation of the GPT models.
- `classifier.py`: To perform sentiment classification using your models.
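`optimizer.py` asks you to fill in a from-scratch optimizer; assuming (as in earlier versions of this project) it is AdamW (Loshchilov & Hutter, 2019), here is a minimal numpy sketch of a single update step with bias correction and decoupled weight decay. The hyperparameter defaults are standard, not prescribed by the handout:

```python
import numpy as np

def adamw_step(p, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update on parameter array p (illustrative sketch)."""
    m = beta1 * m + (1 - beta1) * grad            # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2       # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                  # bias correction
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (np.sqrt(v_hat) + eps)   # Adam step
    p = p - lr * weight_decay * p                 # decoupled weight decay
    return p, m, v
```

Note the weight decay is applied directly to the parameters, not mixed into the gradient as in L2-regularized Adam; that decoupling is the defining feature of AdamW.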
In Part 2 of this project, you will use GPT-2 (via cloze-style classification) to detect whether one sentence is a paraphrase of another, as well as to generate sonnets via autoregressive language modeling.
To test Part 2, you will run:
- `paraphrase_detection.py`: To perform paraphrase detection.
- `sonnet_generation.py`: To perform sonnet generation.
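To make the cloze-style formulation concrete: the sentence pair is rendered as a prompt, and the model's next-token logits for "yes" and "no" at the end of the prompt are compared. The template below is a hypothetical illustration, not necessarily the handout's exact wording:

```python
def cloze_prompt(s1: str, s2: str) -> str:
    """Build a cloze-style prompt for a sentence pair.
    The template is illustrative, not the handout's exact wording."""
    return (f'Question: Is "{s1}" a paraphrase of "{s2}"? '
            f'Answer (yes/no): ')

def cloze_decision(yes_logit: float, no_logit: float) -> bool:
    """Classify by comparing the model's next-token logits for
    "yes" vs. "no" at the final position of the prompt."""
    return yes_logit > no_logit
```

Framing classification this way lets a pure language model act as a classifier without adding a separate classification head.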
Important: Adjust training hyperparameters, particularly batch size, according to your GPU's specifications to optimize performance and prevent out-of-memory errors.
While there are missing code blocks that you need to implement in both of these files, the main focus of this second part is the extensions: how you modify your GPT-2 model to improve its ability to determine whether one sentence is a paraphrase of another, as well as its ability to generate sonnets.
1. Create and activate the conda environment:

   ```
   conda env create -f env.yml
   conda activate cs224n_dfp
   ```

2. (Optional) Install Modal for cloud GPU training:

   ```
   pip install modal
   modal setup  # authenticates your account
   ```

To run paraphrase detection:

```
python paraphrase_detection.py --use_gpu
```

Key arguments:
| Argument | Default | Description |
|---|---|---|
| `--epochs` | 10 | Number of training epochs |
| `--lr` | 1e-5 | Learning rate |
| `--batch_size` | 16 | Batch size |
| `--model_size` | gpt2 | One of gpt2, gpt2-medium, gpt2-large |
| `--use_loreft` | off | Use LoREFT parameter-efficient fine-tuning |
| `--loreft_rank` | 4 | Rank of the LoREFT intervention subspace |
| `--loreft_window_size` | 1 | Number of last tokens to apply LoREFT to |
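LoReFT (Wu et al., 2024) fine-tunes by editing hidden states in a low-rank subspace instead of updating model weights: Φ(h) = h + Rᵀ(Wh + b − Rh), where R has orthonormal rows and its rank corresponds to `--loreft_rank`. A numpy sketch following the paper's formulation; the function names and the window helper are illustrative, not this repo's module:

```python
import numpy as np

def loreft(h, R, W, b):
    """LoReFT intervention on a single hidden state h (d,).
    R: (r, d) with orthonormal rows; W: (r, d); b: (r,)."""
    return h + R.T @ (W @ h + b - R @ h)

def apply_window(H, R, W, b, window=1):
    """Apply the intervention only to the last `window` positions,
    mirroring the --loreft_window_size argument (sketch)."""
    H = H.copy()
    for t in range(max(0, H.shape[0] - window), H.shape[0]):
        H[t] = loreft(H[t], R, W, b)
    return H
```

Only R, W, and b are trained, so the number of tunable parameters scales with the rank r rather than with the model size.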
Example with LoREFT on gpt2-medium:
```
python paraphrase_detection.py --use_gpu \
  --model_size gpt2-medium \
  --epochs 15 \
  --lr 2e-4 \
  --batch_size 128 \
  --use_loreft \
  --loreft_rank 32 \
  --loreft_window_size 4
```

To train on Modal instead:

```
modal run modal_run.py
```

Override defaults from the command line:

```
modal run modal_run.py --epochs 20 --lr 2e-4 --use_loreft --loreft_rank 32
```

Outputs (predictions and logs) are saved to the paraphrase-checkpoints Modal volume. To retrieve them:
```
modal volume ls paraphrase-checkpoints
modal volume get paraphrase-checkpoints <remote-path> <local-path>
```

To run sonnet generation:

```
python sonnet_generation.py --use_gpu
```

Key arguments:
| Argument | Default | Description |
|---|---|---|
| `--epochs` | 10 | Number of training epochs |
| `--lr` | 1e-5 | Learning rate |
| `--batch_size` | 8 | Batch size |
| `--model_size` | gpt2 | One of gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| `--temperature` | 1.2 | Sampling temperature for generation |
| `--top_p` | 0.9 | Nucleus sampling cumulative probability |
| `--tuning_mode` | full | `full` for full fine-tuning or `loreft` for LoREFT |
| `--loreft_rank` | 4 | Rank of the LoREFT intervention subspace |
| `--loreft_dropout` | 0.1 | Dropout on LoREFT intervention output |
| `--loreft_window_size` | 1 | Number of last tokens to apply LoREFT to |
Example with LoREFT on gpt2-medium:
```
python sonnet_generation.py --use_gpu \
  --model_size gpt2-medium \
  --epochs 20 \
  --lr 1e-4 \
  --tuning_mode loreft \
  --loreft_rank 16 \
  --temperature 1.2 \
  --top_p 0.9 \
  --log_curve
```

Generated sonnets are written to `predictions/generated_sonnets.txt`. Dev evaluation uses the chrF score, logged per epoch.
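To clarify what `--temperature` and `--top_p` control during generation, here is a numpy sketch of temperature-scaled nucleus (top-p) sampling. This is the standard algorithm, not necessarily the repo's exact decoding code:

```python
import numpy as np

def sample_top_p(logits, temperature=1.2, top_p=0.9, rng=None):
    """Sample a token id from a logits vector using temperature
    scaling followed by nucleus (top-p) truncation (sketch)."""
    rng = rng if rng is not None else np.random.default_rng()
    z = logits / temperature                 # temperature > 1 flattens
    probs = np.exp(z - z.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]          # tokens, most probable first
    csum = np.cumsum(probs[order])
    # Keep the smallest prefix whose cumulative probability >= top_p
    cutoff = int(np.searchsorted(csum, top_p)) + 1
    kept = order[:cutoff]
    p = probs[kept] / probs[kept].sum()      # renormalize within nucleus
    return int(rng.choice(kept, p=p))
```

Higher temperature spreads probability mass over more tokens (more diverse sonnets), while a lower `top_p` shrinks the nucleus and trims low-probability tokens that tend to produce incoherent text.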
This project is adapted from a prior year's CS 224N project, *Implement BERT*. Parts of the code are from the Hugging Face `transformers` library (Apache License 2.0).