This repo contains implementations for training a reasoning-style model with Supervised Fine-Tuning (SFT, via LoRA) and Proximal Policy Optimization (PPO).
There are two main files:
- Single-GPU SFT/PPO training - the Process Reward Model (PRM - ThinkPRM, 1.5B parameters) and the policy (Qwen2.5 7B) sit on the same GPU.
- Multi-GPU SFT/PPO training - a larger PRM (ThinkPRM 7B) is deployed as a vLLM server, while the policy (Qwen2.5 7B) trains on a separate GPU and makes inference calls to it over HTTP. This PRM doesn't support certain endpoints, so this setup is somewhat experimental.
Both setups were tested on Nvidia A40 GPUs.
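Since the multi-GPU setup queries the PRM over HTTP, a minimal sketch of such a call is shown below. The endpoint path, model name, scoring-prompt format, and port are assumptions for illustration (vLLM's OpenAI-compatible server defaults), not the repo's actual API.

```python
# Hypothetical sketch: querying a remote PRM served behind vLLM's
# OpenAI-compatible completions endpoint. All names here (URL, model
# id, prompt template) are illustrative assumptions.
import requests

PRM_URL = "http://localhost:8000/v1/completions"  # assumed vLLM default port


def build_prm_request(solution_text: str, model: str = "ThinkPRM-7B") -> dict:
    """Package a candidate solution as a completion request for the PRM."""
    return {
        "model": model,
        "prompt": f"Judge the following reasoning step by step:\n{solution_text}",
        "max_tokens": 256,
        "temperature": 0.0,  # deterministic scoring
    }


def score_with_prm(solution_text: str) -> str:
    """POST the request to the vLLM server and return the raw verdict text."""
    resp = requests.post(PRM_URL, json=build_prm_request(solution_text), timeout=60)
    resp.raise_for_status()
    return resp.json()["choices"][0]["text"]
```

Because the PRM is only reachable through this HTTP boundary, the policy's training loop never has to load the 7B reward model into its own GPU memory.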
- SFT for the initial reasoning policy.
- Fine-tunes the base model on a reasoning dataset (the Countdown maths task: reach a target number from three given numbers) to initialise PPO training. The dataset is loaded from the Hugging Face Hub.
- PPO training loop where both the policy model and the PRM run on a single GPU.
- Suitable for smaller-scale experiments.
- PPO training with two GPUs.
- The PRM is accessed remotely via an HTTP call to a vLLM server.
- Allows scaling to larger PRMs while keeping the policy model training isolated.
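To make the Countdown task above concrete, here is a small self-contained sketch of an instance checker and a binary reward. The function names and the all-or-nothing reward scheme are assumptions for illustration, and the solver only tries left-to-right grouping for brevity.

```python
# Illustrative sketch of the Countdown maths task: combine three given
# numbers with +, -, *, / to reach a target. Names and the binary
# reward are hypothetical, not the repo's actual implementation.
from itertools import permutations, product
from operator import add, sub, mul, truediv

OPS = [add, sub, mul, truediv]


def countdown_solvable(numbers, target):
    """Return True if some ordering of the three numbers with two
    operators reaches the target (left-to-right grouping only)."""
    for a, b, c in permutations(numbers):
        for op1, op2 in product(OPS, repeat=2):
            try:
                if abs(op2(op1(a, b), c) - target) < 1e-9:
                    return True
            except ZeroDivisionError:
                continue
    return False


def reward(numbers, target, model_answer: float) -> float:
    """Binary reward: 1.0 if the model's final number hits the target."""
    return 1.0 if abs(model_answer - target) < 1e-9 else 0.0
```

For example, `countdown_solvable([2, 3, 4], 20)` holds because `(2 + 3) * 4 = 20`; a rollout ending in any other number would receive reward 0.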
torch
transformers
datasets
tqdm
trl
peft
requests
vllm