From 4d704a016bcd8330bc7551cad42d4b92129a53e7 Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Wed, 5 Nov 2025 19:34:03 +0000 Subject: [PATCH 01/20] Added CISPO loss function --- .../config/_generated_ppo_trainer.yaml | 2 + verl/trainer/config/actor/actor.yaml | 8 ++- verl/trainer/ppo/core_algos.py | 57 +++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 444595c7..fe94e7a4 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -56,6 +56,8 @@ actor_rollout_ref: clip_cov_ub: 5.0 kl_cov_ratio: 0.0002 ppo_kl_coef: 0.1 + cispo_clip_ratio_high: 0.2 + cispo_clip_ratio_low: 0.2 clip_ratio_c: 3.0 loss_agg_mode: token-mean entropy_coeff: 0 diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml index 7c733ed6..e8c082e6 100644 --- a/verl/trainer/config/actor/actor.yaml +++ b/verl/trainer/config/actor/actor.yaml @@ -49,7 +49,7 @@ policy_loss: # # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs _target_: verl.workers.config.PolicyLossConfig - # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617 + # Loss function mode: vanilla / clip-cov / kl-cov / cispo (from https://arxiv.org/abs/2506.13585) / gpg (from https://arxiv.org/abs/2505.22617) loss_mode: "vanilla" # Ratio of tokens to be clipped for clip-cov loss @@ -67,6 +67,12 @@ policy_loss: # KL divergence penalty coefficient ppo_kl_coef: 0.1 + # Upper bound for CISPO importance ratio clipping + cispo_clip_ratio_high: 0.2 + + # Lower bound for CISPO importance ratio clipping + cispo_clip_ratio_low: 0.2 + # Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C clip_ratio_c: 3.0 diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 7a9103c4..5eb0b483 100644 --- 
@register_policy_loss("cispo")
def compute_policy_loss_cispo(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config: Optional[DictConfig | AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the CISPO policy objective and related metrics.

    CISPO (Clipped IS-weight Policy Optimization) clips the importance-sampling
    weight itself — and stops its gradient — instead of clipping the surrogate
    objective and thereby dropping tokens. This retains gradient signal on sparse
    but critical tokens, which the paper reports is beneficial for long-context
    reasoning in RL.

    Reference: https://arxiv.org/abs/2506.13585

    Args:
        old_log_prob (torch.Tensor):
            Log-probabilities of actions under the old policy, shape (batch_size, response_length).
        log_prob (torch.Tensor):
            Log-probabilities of actions under the current policy, shape (batch_size, response_length).
        advantages (torch.Tensor):
            Advantage estimates for each action, shape (batch_size, response_length).
        response_mask (torch.Tensor):
            Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
        loss_agg_mode (str, optional):
            Aggregation mode for `agg_loss` (e.g. "token-mean").
        config (AlgoConfig):
            Algorithm configuration containing CISPO parameters
            (`policy_loss.cispo_clip_ratio_high` / `policy_loss.cispo_clip_ratio_low`).

    Returns:
        tuple: (pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower), where the clip
        fractions report how often the IS weight hit the upper/lower clip bound.
    """
    # Setup CISPO configuration.
    assert config is not None, "CISPO requires a config carrying policy_loss settings"
    assert config.policy_loss.loss_mode == "cispo", "CISPO loss mode not set in config"
    cispo_clip_ratio_high = config.policy_loss.cispo_clip_ratio_high
    cispo_clip_ratio_low = config.policy_loss.cispo_clip_ratio_low
    # NOTE: unlike vanilla PPO, CISPO has no dual-clip branch, so `clip_ratio_c`
    # is intentionally not read here (it was unused copy-paste residue).

    negative_approx_kl = log_prob - old_log_prob
    # Clamp for numerical stability before exponentiation.
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)

    # CISPO-specific loss: stop gradient through the IS ratio, then clip the
    # weight itself rather than the surrogate objective.
    ratio = ratio.detach()
    importance_sampling_weight = torch.clamp(
        ratio, min=1.0 - cispo_clip_ratio_low, max=1.0 + cispo_clip_ratio_high
    )
    # REINFORCE-style token loss scaled by the clipped, detached IS weight.
    pg_losses = -advantages * log_prob * importance_sampling_weight

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    # Metrics only: fraction of tokens whose IS weight was clipped at each bound
    # (computed on the correct device, mirroring the other policy losses).
    pg_clipfrac = verl_F.masked_mean(
        torch.gt(ratio, 1.0 + cispo_clip_ratio_high).float(), response_mask
    )
    pg_clipfrac_lower = verl_F.masked_mean(
        torch.lt(ratio, 1.0 - cispo_clip_ratio_low).float(), response_mask
    )

    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
create mode 100644 scripts/train/test_k2p_cispo_m2.sh create mode 100644 scripts/train/test_k2p_grpo_m2.sh diff --git a/scripts/train/test_k2p_cispo_m2.sh b/scripts/train/test_k2p_cispo_m2.sh new file mode 100644 index 00000000..964c3db7 --- /dev/null +++ b/scripts/train/test_k2p_cispo_m2.sh @@ -0,0 +1,303 @@ +#!/bin/bash +#SBATCH --job-name=cispo-focused +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --account=iq +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/mnt/weka/home/taylor.killian/miniconda3/envs/sync-rl/bin/ +export NCCL_TIMEOUT_SECONDS=4800 +export TORCH_NCCL_ENABLE_MONITORING=0 +export NCCL_DEBUG=warn +export NCCL_NET=IB +export NCCL_IB_HCA="mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7" +export NCCL_CROSS_NIC=1 +export NCCL_IB_TC=136 +export NCCL_SOCKET_IFNAME="^lo,docker,virbr" +export CUDA_DEVICE_MAX_CONNECTIONS=8 +export NCCL_NVLS_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== 
+SHARED_DATA_PATH=/mnt/sharefs/users/zhuojun.cheng +SHARED_MODEL_PATH=/mnt/sharefs/users/haonan.li/models +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/guru_data/train/guru92k_release_0603 +TEST_DATA_DIR=${SHARED_DATA_PATH}/guru_data/test/online # ← unchanged + +# ---------- Math ---------- +# train +math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet +# test +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# ---------- Code ---------- +# train +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_1.3k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_440.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_7.5k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_8.8k.parquet +# test (unchanged) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500_sampled_200.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# ---------- Logic ---------- +# train +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.2k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_1.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_1.3k.parquet +# test (unchanged) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300_sampled_200.parquet +graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet +arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet + +# 
---------- Simulation ---------- +# train +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k_3.7k.parquet +# test (unchanged) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet + +# ---------- Table ---------- +# train +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet +# test (unchanged) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_300_sampled_200.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_300_sampled_200.parquet + +# ---------- Stem ---------- +# train +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet +# test (unchanged) +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet + +# Full Guru92k mixture +# train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${graph_test_path}','${ordering_puzzle_test_path}','${arcagi1_test_path}','${codeio_test_path}','${multihier_test_path}','${hitab_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Focused Guru92k mixture (Math + Code + STEM) +train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}']" 
+test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# =================== Model =================== +BASE_MODEL=Qwen/Qwen2.5-7B + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${BASE_MODEL##*/}-${SLURM_JOB_ID} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. 
+ +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 32)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_mode="cispo" # Default is "vanilla" which is equivalent to PPO; +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=2 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + 
data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.dtype=${rollout_dtype} \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=1 \ No newline at end of file diff --git a/scripts/train/test_k2p_grpo_m2.sh b/scripts/train/test_k2p_grpo_m2.sh new file mode 100644 index 00000000..f5d6df24 --- /dev/null +++ b/scripts/train/test_k2p_grpo_m2.sh @@ -0,0 +1,303 @@ +#!/bin/bash +#SBATCH --job-name=grpo-focused +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --account=iq +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log 
+#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/mnt/weka/home/taylor.killian/miniconda3/envs/sync-rl/bin/ +export NCCL_TIMEOUT_SECONDS=4800 +export TORCH_NCCL_ENABLE_MONITORING=0 +export NCCL_DEBUG=warn +export NCCL_NET=IB +export NCCL_IB_HCA="mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7" +export NCCL_CROSS_NIC=1 +export NCCL_IB_TC=136 +export NCCL_SOCKET_IFNAME="^lo,docker,virbr" +export CUDA_DEVICE_MAX_CONNECTIONS=8 +export NCCL_NVLS_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== +SHARED_DATA_PATH=/mnt/sharefs/users/zhuojun.cheng +SHARED_MODEL_PATH=/mnt/sharefs/users/haonan.li/models +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/guru_data/train/guru92k_release_0603 +TEST_DATA_DIR=${SHARED_DATA_PATH}/guru_data/test/online # ← unchanged + +# ---------- Math ---------- +# train +math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet +# test +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# ---------- Code 
---------- +# train +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_1.3k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_440.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_7.5k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_8.8k.parquet +# test (unchanged) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500_sampled_200.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# ---------- Logic ---------- +# train +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.2k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_1.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_1.3k.parquet +# test (unchanged) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300_sampled_200.parquet +graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet +arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet + +# ---------- Simulation ---------- +# train +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k_3.7k.parquet +# test (unchanged) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet + +# ---------- Table ---------- +# train +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet +# test (unchanged) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_300_sampled_200.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_300_sampled_200.parquet + +# 
---------- Stem ---------- +# train +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet +# test (unchanged) +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet + +# Full Guru92k mixture +# train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${graph_test_path}','${ordering_puzzle_test_path}','${arcagi1_test_path}','${codeio_test_path}','${multihier_test_path}','${hitab_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Focused Guru92k mixture (Math + Code + STEM) +train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}']" +test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# =================== Model =================== +BASE_MODEL=Qwen/Qwen2.5-7B + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${BASE_MODEL##*/}-${SLURM_JOB_ID} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num 
--ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. 
+ +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 32)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_mode="vanilla" # Default is "vanilla" which is equivalent to PPO; +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=2 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ 
+ data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.dtype=${rollout_dtype} \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=1 \ No newline at end of file diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 5eb0b483..82901040 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -1163,6 +1163,7 @@ def compute_policy_loss_kl_cov( return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0) + @register_policy_loss("cispo") def compute_policy_loss_cispo( old_log_prob: torch.Tensor, @@ -1218,8 +1219,12 @@ def 
compute_policy_loss_cispo( pg_losses = -advantages * log_prob * importance_sampling_weight pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + # For compatibility, return zero for pg_clipfrac_lower and pg_clipfrac (not used in CISPO) + pg_clipfrac = torch.tensor(0.0, device=pg_loss.device) + pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower - return pg_loss, torch.tensor(0.0), ppo_kl, torch.tensor(0.0) # Not computing clip fractions for CISPO @register_policy_loss("geo_mean") def compute_policy_loss_geo_mean( diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py index af619973..011d2ebc 100644 --- a/verl/workers/config/actor.py +++ b/verl/workers/config/actor.py @@ -49,6 +49,8 @@ class PolicyLossConfig(BaseConfig): clip_cov_ub: float = 5.0 kl_cov_ratio: float = 0.0002 ppo_kl_coef: float = 0.1 + cispo_clip_ratio_high: float = 0.2 + cispo_clip_ratio_low: float = 0.2 @dataclass @@ -225,6 +227,7 @@ class FSDPActorConfig(ActorConfig): """ strategy: str = "fsdp" + dtype: str = "bfloat16" grad_clip: float = 1.0 ulysses_sequence_parallel_size: int = 1 entropy_from_logits_with_chunking: bool = False From 015a6b85b3b65023171991c0feb6ce1bba18cc70 Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Thu, 6 Nov 2025 17:46:45 +0000 Subject: [PATCH 03/20] Completing verl-expected policy loss signature --- scripts/train/test_k2p_cispo_m2.sh | 1 + scripts/train/test_k2p_grpo_m2.sh | 1 + verl/trainer/ppo/core_algos.py | 9 +++++++++ 3 files changed, 11 insertions(+) diff --git a/scripts/train/test_k2p_cispo_m2.sh b/scripts/train/test_k2p_cispo_m2.sh index 964c3db7..abf6e6fd 100644 --- a/scripts/train/test_k2p_cispo_m2.sh +++ b/scripts/train/test_k2p_cispo_m2.sh @@ -280,6 +280,7 @@ offload=True actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.rollout.multi_turn.enable=False \ actor_rollout_ref.rollout.mode="sync" \ + 
actor_rollout_ref.rollout.dtype=${rollout_dtype} \ +actor_rollout_ref.model.override_config.attention_dropout=0. \ +actor_rollout_ref.model.override_config.embd_pdrop=0. \ +actor_rollout_ref.model.override_config.resid_pdrop=0. \ diff --git a/scripts/train/test_k2p_grpo_m2.sh b/scripts/train/test_k2p_grpo_m2.sh index f5d6df24..0f93b728 100644 --- a/scripts/train/test_k2p_grpo_m2.sh +++ b/scripts/train/test_k2p_grpo_m2.sh @@ -280,6 +280,7 @@ offload=True actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.rollout.multi_turn.enable=False \ actor_rollout_ref.rollout.mode="sync" \ + actor_rollout_ref.rollout.dtype=${rollout_dtype} \ +actor_rollout_ref.model.override_config.attention_dropout=0. \ +actor_rollout_ref.model.override_config.embd_pdrop=0. \ +actor_rollout_ref.model.override_config.resid_pdrop=0. \ diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 82901040..d78ada89 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -1172,6 +1172,7 @@ def compute_policy_loss_cispo( response_mask: torch.Tensor, loss_agg_mode: str = "token-mean", config: Optional[DictConfig | AlgoConfig] = None, + rollout_log_probs: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute the CISPO policy objective and related metrics. @@ -1192,6 +1193,8 @@ def compute_policy_loss_cispo( Aggregation mode for loss computation config (AlgoConfig): Algorithm configuration containing CISPO parameters + rollout_log_probs: `(torch.Tensor)`: + log probabilities of actions under the rollout policy, shape (batch_size, response_length). 
Returns: tuple: (pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower) """ @@ -1218,6 +1221,12 @@ def compute_policy_loss_cispo( importance_sampling_weight = torch.clamp(ratio, min=1-cispo_clip_ratio_low, max=1+cispo_clip_ratio_high) pg_losses = -advantages * log_prob * importance_sampling_weight + if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None: + # Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl + tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs) + tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) + pg_losses = pg_losses * tis_imp_ratio + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) # For compatibility, return zero for pg_clipfrac_lower and pg_clipfrac (not used in CISPO) pg_clipfrac = torch.tensor(0.0, device=pg_loss.device) From 154135e3da83816fb63dda55cc8705dcb034fc55 Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Sat, 8 Nov 2025 00:25:14 +0000 Subject: [PATCH 04/20] Update to rull full on-policy --- scripts/train/test_k2p_cispo_m2.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/train/test_k2p_cispo_m2.sh b/scripts/train/test_k2p_cispo_m2.sh index abf6e6fd..a3e816c0 100644 --- a/scripts/train/test_k2p_cispo_m2.sh +++ b/scripts/train/test_k2p_cispo_m2.sh @@ -1,5 +1,5 @@ #!/bin/bash -#SBATCH --job-name=cispo-focused +#SBATCH --job-name=cispo-focused-onpolicy-bsz256 #SBATCH --nodes=4 #SBATCH --ntasks=4 #SBATCH --ntasks-per-node=1 @@ -166,7 +166,7 @@ use_kl_loss=False kl_loss_coef=0.0 clip_ratio_low=0.2 -clip_ratio_high=0.2 +clip_ratio_high=0.28 max_prompt_length=$((1024 * 4)) max_response_length=$((1024 * 32)) @@ -184,7 +184,7 @@ max_num_gen_batches=10 train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n gen_prompt_bsz=$((train_prompt_bsz * 1)) n_resp_per_prompt=16 -train_prompt_mini_bsz=32 # model grad update batchsize +train_prompt_mini_bsz=256 # model grad update 
batchsize # Algorithm temperature=1.0 From fd9fd42d8e63ae03d8dab13f5109cbbd830a7738 Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Sat, 8 Nov 2025 06:52:47 +0000 Subject: [PATCH 05/20] Fixing CISPO IS clipping to be one-sided as intended --- scripts/train/test_k2p_cispo_m2.sh | 4 ++-- scripts/train/test_k2p_grpo_m2.sh | 2 +- verl/trainer/ppo/core_algos.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/train/test_k2p_cispo_m2.sh b/scripts/train/test_k2p_cispo_m2.sh index a3e816c0..3b73c65d 100644 --- a/scripts/train/test_k2p_cispo_m2.sh +++ b/scripts/train/test_k2p_cispo_m2.sh @@ -15,7 +15,7 @@ # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +RESUME_CKPT_DIR_NAME="cispo-focused-onpolicy-bsz256-Qwen2.5-7B-998181" # Fill in the checkpoint directory name to resume from, otherwise from scratch export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain # =================== Cluster Environment =================== @@ -166,7 +166,7 @@ use_kl_loss=False kl_loss_coef=0.0 clip_ratio_low=0.2 -clip_ratio_high=0.28 +clip_ratio_high=2.0 max_prompt_length=$((1024 * 4)) max_response_length=$((1024 * 32)) diff --git a/scripts/train/test_k2p_grpo_m2.sh b/scripts/train/test_k2p_grpo_m2.sh index 0f93b728..aefbf34f 100644 --- a/scripts/train/test_k2p_grpo_m2.sh +++ b/scripts/train/test_k2p_grpo_m2.sh @@ -15,7 +15,7 @@ # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +RESUME_CKPT_DIR_NAME="grpo-focused-Qwen2.5-7B-994853" # Fill in the checkpoint directory name to resume from, otherwise from scratch export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain # =================== 
Cluster Environment =================== diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index d78ada89..425f095b 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -1218,7 +1218,7 @@ def compute_policy_loss_cispo( # CISPO specific loss ratio = ratio.detach() # Stop gradient on IS ratio - importance_sampling_weight = torch.clamp(ratio, min=1-cispo_clip_ratio_low, max=1+cispo_clip_ratio_high) + importance_sampling_weight = torch.clamp(ratio, max=1+cispo_clip_ratio_high) pg_losses = -advantages * log_prob * importance_sampling_weight if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None: From 73d2b95951bf5ac545b3ada737aafc0fa0d95441 Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Fri, 14 Nov 2025 07:32:55 +0000 Subject: [PATCH 06/20] Sync --- scripts/train/test_k2p_cispo_m2.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/train/test_k2p_cispo_m2.sh b/scripts/train/test_k2p_cispo_m2.sh index 3b73c65d..b4a49796 100644 --- a/scripts/train/test_k2p_cispo_m2.sh +++ b/scripts/train/test_k2p_cispo_m2.sh @@ -1,5 +1,5 @@ #!/bin/bash -#SBATCH --job-name=cispo-focused-onpolicy-bsz256 +#SBATCH --job-name=cispo-focused-fixed #SBATCH --nodes=4 #SBATCH --ntasks=4 #SBATCH --ntasks-per-node=1 @@ -15,7 +15,7 @@ # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="cispo-focused-onpolicy-bsz256-Qwen2.5-7B-998181" # Fill in the checkpoint directory name to resume from, otherwise from scratch +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain # =================== Cluster Environment =================== @@ -184,7 +184,7 @@ max_num_gen_batches=10 train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n gen_prompt_bsz=$((train_prompt_bsz * 
1)) n_resp_per_prompt=16 -train_prompt_mini_bsz=256 # model grad update batchsize +train_prompt_mini_bsz=32 # model grad update batchsize # Algorithm temperature=1.0 From 99c7c65203edebe59b9e2e4c4d2abdbba89be6c6 Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Fri, 14 Nov 2025 08:16:17 +0000 Subject: [PATCH 07/20] sync --- scripts/train/test_k2p_cispo_m2.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train/test_k2p_cispo_m2.sh b/scripts/train/test_k2p_cispo_m2.sh index b4a49796..01cf6ae8 100644 --- a/scripts/train/test_k2p_cispo_m2.sh +++ b/scripts/train/test_k2p_cispo_m2.sh @@ -15,7 +15,7 @@ # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +RESUME_CKPT_DIR_NAME="cispo-focused-fixed-Qwen2.5-7B-1016971" # Fill in the checkpoint directory name to resume from, otherwise from scratch export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain # =================== Cluster Environment =================== From 90190747863ee3e11de6c05a6624bfdeaf767495 Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Thu, 20 Nov 2025 23:41:19 +0000 Subject: [PATCH 08/20] Sync --- recipe/dapo/dapo_ray_trainer.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py index cffb342e..2e144130 100644 --- a/recipe/dapo/dapo_ray_trainer.py +++ b/recipe/dapo/dapo_ray_trainer.py @@ -251,9 +251,10 @@ def fit(self): batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) prompt_bsz = self.config.data.train_batch_size - if num_prompt_in_batch < prompt_bsz: + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if num_prompt_in_batch < prompt_bsz and max_num_gen_batches > 1: # Added by Reasoning360 TWK NOTE: second condition is to account 
for when we have zero-variance filtering but are not dynamically growing the batch... print(f"{num_prompt_in_batch=} < {prompt_bsz=}") - max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + # max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: print(f"{num_gen_batches=}. Keep generating...") progress_bar.update(1) @@ -267,9 +268,14 @@ def fit(self): + " You could also try set max_num_gen_batches=0 to enable endless trials." ) else: - # Align the batch - traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n - batch = batch[:traj_bsz] + # Added by Reasoning360, need to account for when our batch is smaller due to zero-variance filtering + if num_prompt_in_batch >= prompt_bsz: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + batch = batch[:traj_bsz] + else: + # TWK TODO!!!: RESCALE THIS SO THAT THE BATCH*N IS DIVISIBLE BY k_partitions (n_gpus...) + print(f"Final {num_prompt_in_batch=} < {prompt_bsz=} after {num_gen_batches=} generation batches. 
Proceeding with smaller batch...") # === Updating === From 312ed82c57d5a953a09104c4eb815511ff02d6ee Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Wed, 10 Dec 2025 05:41:51 +0000 Subject: [PATCH 09/20] Sync with Latex and sympy commented out from naive_dapo.py --- scripts/train/k2p_hero_cispo.sh | 326 +++++++++++++++++ scripts/train/k2p_hero_grpo.sh | 328 ++++++++++++++++++ verl/utils/reward_score/naive_dapo.py | 6 +- .../utils/reward_score/prime_math/__init__.py | 2 +- 4 files changed, 658 insertions(+), 4 deletions(-) create mode 100644 scripts/train/k2p_hero_cispo.sh create mode 100644 scripts/train/k2p_hero_grpo.sh diff --git a/scripts/train/k2p_hero_cispo.sh b/scripts/train/k2p_hero_cispo.sh new file mode 100644 index 00000000..f37efe39 --- /dev/null +++ b/scripts/train/k2p_hero_cispo.sh @@ -0,0 +1,326 @@ +#!/bin/bash +#SBATCH --job-name=cispo-focused-k2p-finalInstruct-temp1.0-wOmni-fix2 +#SBATCH --nodes=32 +#SBATCH --ntasks=32 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main +#SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394 + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ 
+NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== +SHARED_DATA_PATH=/lustrefs/users/haonan.li/data/k2 +MATH_DATA_PATH=/lustrefs/users/zhuojun.cheng/vpim/guru_data/train/postprocessed_dedup_am_semantic_filtered_0.05_0.94_thresh_ratio0.5_sample1.0_balanced_step2 +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_1_datamix_6 +TEST_DATA_DIR=${SHARED_DATA_PATH}/test_12k_len + +# Math (train) +math_train_path1=${MATH_DATA_PATH}/math__combined_118.2k.part1_scored.parquet +math_train_path2=${MATH_DATA_PATH}/math__combined_118.2k.part2_scored.parquet +# math_train_path1=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +# math_train_path2=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet 
+livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoninggym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +# codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet 
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (train) +if_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet + +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${reasoninggym_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}','${nemotron_train_path}','${if_train_path}']" # '${synlogic_train_path}', +# 
test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # '${synlogic_test_path}', + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 
--ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=2.0 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 32)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_mode="cispo" # Default is 'vanilla' which is equivalent to PPO; +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.0 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Reduced from 32 to reduce memory pressure +gen_tp=4 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + 
algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 \ No newline at end of file diff --git a/scripts/train/k2p_hero_grpo.sh b/scripts/train/k2p_hero_grpo.sh new file mode 100644 index 00000000..0168d29e --- /dev/null +++ b/scripts/train/k2p_hero_grpo.sh @@ -0,0 +1,328 @@ +#!/bin/bash +#SBATCH --job-name=grpo-k2p-32k264k-stage2-fromStep200-focused +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH 
--mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main +#SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394 + +# SBATCH --job-name=grpo-hero-k2p-finalInstruct-temp1.2-wOmni-fix2 + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="grpo-k2p-32k264k-stage2-focused-404083" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== 
+SHARED_DATA_PATH=/lustrefs/users/haonan.li/data/k2 +MATH_DATA_PATH=/lustrefs/users/zhuojun.cheng/vpim/guru_data/train/postprocessed_dedup_am_semantic_filtered_0.05_0.94_thresh_ratio0.5_sample1.0_balanced_step2 +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_1_datamix_6 +TEST_DATA_DIR=${SHARED_DATA_PATH}/test_12k_len + +# Math (train) +math_train_path1=${MATH_DATA_PATH}/math__combined_118.2k.part1_scored.parquet +math_train_path2=${MATH_DATA_PATH}/math__combined_118.2k.part2_scored.parquet +# math_train_path1=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +# math_train_path2=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet 
+reasoninggym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +# codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (train) +if_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet + +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" 
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${synlogic_train_path}','${reasoninggym_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}','${nemotron_train_path}','${if_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # '${synlogic_test_path}', + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906/global_step_150/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} 
#-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. 
+ +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 64)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.2 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Reduced from 32 to reduce memory pressure +gen_tp=4 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + 
data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 \ No newline at end of file diff --git a/verl/utils/reward_score/naive_dapo.py b/verl/utils/reward_score/naive_dapo.py index d26a1dd7..102d714e 100644 --- a/verl/utils/reward_score/naive_dapo.py +++ b/verl/utils/reward_score/naive_dapo.py @@ -189,7 +189,7 @@ def _parse_latex(expr: str) -> str: expr = expr.replace("\\tfrac", "\\frac") expr = expr.replace("\\dfrac", "\\frac") expr = expr.replace("\\frac", " \\frac") # Play 
nice with mixed numbers. - expr = latex2text.LatexNodes2Text().latex_to_text(expr) + # expr = latex2text.LatexNodes2Text().latex_to_text(expr) # Replace the specific characters that this parser uses. expr = expr.replace("√", "sqrt") @@ -425,8 +425,8 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) is_correct = False - else: - is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + # else: + # is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) if not is_correct: break diff --git a/verl/utils/reward_score/prime_math/__init__.py b/verl/utils/reward_score/prime_math/__init__.py index 8d9d273e..d0ea47ac 100644 --- a/verl/utils/reward_score/prime_math/__init__.py +++ b/verl/utils/reward_score/prime_math/__init__.py @@ -55,7 +55,7 @@ def _parse_latex(expr: str) -> str: expr = expr.replace("\\tfrac", "\\frac") expr = expr.replace("\\dfrac", "\\frac") expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. - expr = latex2text.LatexNodes2Text().latex_to_text(expr) + # expr = latex2text.LatexNodes2Text().latex_to_text(expr) # Replace the specific characters that this parser uses. 
expr = expr.replace("√", "sqrt") From 61cb565be8a428c4a090132d29c76c95a5395881 Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Thu, 11 Dec 2025 05:11:19 +0000 Subject: [PATCH 10/20] Sync --- scripts/train/k2p_hero_grpo.sh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/train/k2p_hero_grpo.sh b/scripts/train/k2p_hero_grpo.sh index 0168d29e..b0c04dba 100644 --- a/scripts/train/k2p_hero_grpo.sh +++ b/scripts/train/k2p_hero_grpo.sh @@ -1,7 +1,7 @@ #!/bin/bash -#SBATCH --job-name=grpo-k2p-32k264k-stage2-fromStep200-focused -#SBATCH --nodes=64 -#SBATCH --ntasks=64 +#SBATCH --job-name=grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2 +#SBATCH --nodes=32 +#SBATCH --ntasks=32 #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:8 #SBATCH --cpus-per-task=96 @@ -16,7 +16,7 @@ # SBATCH --job-name=grpo-hero-k2p-finalInstruct-temp1.2-wOmni-fix2 # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="grpo-k2p-32k264k-stage2-focused-404083" # Fill in the checkpoint directory name to resume from, otherwise from scratch +RESUME_CKPT_DIR_NAME="grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906" # Fill in the checkpoint directory name to resume from, otherwise from scratch export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain # =================== Cluster Environment =================== @@ -130,11 +130,11 @@ if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet # Focused data mixture (math, code, stem) -# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" 
+train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" # Full data mixture (uncomment to use) -train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${synlogic_train_path}','${reasoninggym_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}','${nemotron_train_path}','${if_train_path}']" +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${reasoninggym_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}','${nemotron_train_path}','${if_train_path}']" # '${synlogic_train_path}', # test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # '${synlogic_test_path}', @@ -196,7 +196,7 @@ clip_ratio_low=0.2 
clip_ratio_high=0.28 max_prompt_length=$((1024 * 4)) -max_response_length=$((1024 * 64)) +max_response_length=$((1024 * 32)) enable_overlong_buffer=False overlong_buffer_len=$((1024 * 12)) overlong_penalty_factor=1.0 From dc631bdaa35e4216f36a0e258f768c2170f88271 Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Tue, 16 Dec 2025 00:24:46 +0000 Subject: [PATCH 11/20] Updating rewards to better handle OmniMath scoring --- scripts/train/k2p_hero_grpo.sh | 14 +- scripts/train/k2p_hero_grpo_newData.sh | 343 ++++++++++++++++++ .../reward_score/math_llm_judge/__init__.py | 12 +- verl/utils/reward_score/naive_dapo.py | 5 +- 4 files changed, 359 insertions(+), 15 deletions(-) create mode 100644 scripts/train/k2p_hero_grpo_newData.sh diff --git a/scripts/train/k2p_hero_grpo.sh b/scripts/train/k2p_hero_grpo.sh index b0c04dba..e0ec0988 100644 --- a/scripts/train/k2p_hero_grpo.sh +++ b/scripts/train/k2p_hero_grpo.sh @@ -1,7 +1,7 @@ #!/bin/bash -#SBATCH --job-name=grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2 -#SBATCH --nodes=32 -#SBATCH --ntasks=32 +#SBATCH --job-name=grpo-k2p-finalInstruct-64k-temp1.2-focused +#SBATCH --nodes=64 +#SBATCH --ntasks=64 #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:8 #SBATCH --cpus-per-task=96 @@ -16,7 +16,7 @@ # SBATCH --job-name=grpo-hero-k2p-finalInstruct-temp1.2-wOmni-fix2 # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906" # Fill in the checkpoint directory name to resume from, otherwise from scratch +RESUME_CKPT_DIR_NAME="grpo-k2p-finalInstruct-64k-temp1.2-focused-404084" # Fill in the checkpoint directory name to resume from, otherwise from scratch export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain # =================== Cluster Environment =================== @@ -142,7 +142,7 @@ 
test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${a # BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) # BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k -# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906/global_step_150/actor/huggingface +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906/global_step_300/actor/huggingface # =================== Logging =================== WANDB_PROJECT=k2plus_rl @@ -196,7 +196,7 @@ clip_ratio_low=0.2 clip_ratio_high=0.28 max_prompt_length=$((1024 * 4)) -max_response_length=$((1024 * 32)) +max_response_length=$((1024 * 64)) enable_overlong_buffer=False overlong_buffer_len=$((1024 * 12)) overlong_penalty_factor=1.0 @@ -317,7 +317,7 @@ offload=True trainer.logger=['console','wandb'] \ trainer.project_name=${WANDB_PROJECT} \ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ - trainer.val_before_train=True \ + trainer.val_before_train=False \ trainer.n_gpus_per_node=8 \ trainer.nnodes=$worker_num \ trainer.save_freq=10 \ diff --git a/scripts/train/k2p_hero_grpo_newData.sh b/scripts/train/k2p_hero_grpo_newData.sh new file mode 100644 index 00000000..ade66887 --- /dev/null +++ b/scripts/train/k2p_hero_grpo_newData.sh @@ -0,0 +1,343 @@ +#!/bin/bash +#SBATCH --job-name=grpo-k2p-newFiltered-64k-fullData-finalInstruct +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH 
--time=720:00:00 +#SBATCH --partition=main +#SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394 + +# SBATCH --job-name=grpo-hero-k2p-finalInstruct-temp1.2-wOmni-fix2 + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-033:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture 
=================== + +# Training Data Configuration +DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1" +train_file_list=() + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + "codegen__deduped_leetcode2k_2.4k.parquet" + "codegen__deduped_livecodebench_599.parquet" + "codegen__deduped_primeintellect_9.6k.parquet" + "codegen__deduped_taco_11.1k.parquet" + "ifbench__fixed_85.6k.parquet" + "logic__arcagi1_297.parquet" + "logic__arcagi2_653.parquet" + "logic__barc_3.4k.parquet" + "logic__graph_logical_dataset_1.4k.parquet" + "logic__ordering_puzzle_dataset_2.9k.parquet" + "logic__reasoning_gym_40.6k.parquet" + "logic__synlogic_12.1k.parquet" + "logic__zebra_puzzle_dataset_5.0k.parquet" + "math__combined_118.2k.part1.parquet" + "math__combined_118.2k.part2.parquet" + "omni_math_4.43k_dedup.parquet" + "simulation__codeio_fixed_12.1k.parquet" + "stem__nemotron_13.3k.parquet" + "stem__web_31.7k.parquet" + "table__hitab_7.4k.parquet" + "table__multihier_2.9k.parquet" +) + +echo "Collecting training files from ${DATA_MIX_DIR}..." 
+ +# Search for each dataset in all subdirectories +for dataset in "${dataset_names[@]}"; do + for subdir in "impossible_questions" "131k_context_questions" "main_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +unset IFS + +echo "Total training files found: ${#train_file_list[@]}" + +# Test Data Configuration +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (test) 
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${synlogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906/global_step_300/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; 
then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. 
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 64))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=256 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Reduced from 32 to reduce memory pressure
+gen_tp=4
+gen_max_num_seqs=1024 # Lower this to reduce memory pressure
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+    --config-path=config \
+    --config-name="dapo_fsdp_config.yaml" \
+    algorithm.adv_estimator=${adv_estimator} \
+    algorithm.use_kl_in_reward=${use_kl_in_reward} \
+    algorithm.kl_ctrl.kl_coef=${kl_coef} \
+    algorithm.filter_groups.enable=${enable_filter_groups} \
+    algorithm.filter_groups.metric=${filter_groups_metric} \
+    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+    data.train_files="$train_files" \
+    data.val_files="$test_files" \
+    data.prompt_key=prompt \
+    data.truncation='right' \
+    data.max_prompt_length=${max_prompt_length} \
+    data.max_response_length=${max_response_length} \
+    data.train_batch_size=${train_prompt_bsz} \
+    data.gen_batch_size=${gen_prompt_bsz} \
+    actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+    actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+    actor_rollout_ref.actor.clip_ratio_c=10.0 \
+    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+    actor_rollout_ref.actor.strategy="fsdp2" \
+    actor_rollout_ref.actor.optim.lr=5e-7 \
+    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+    actor_rollout_ref.actor.optim.weight_decay=0.1 \
+    actor_rollout_ref.actor.optim.warmup_style=constant \
+    actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 \ No newline at end of file diff --git a/verl/utils/reward_score/math_llm_judge/__init__.py b/verl/utils/reward_score/math_llm_judge/__init__.py index 7d2b20ca..3d78e1c6 100644 --- a/verl/utils/reward_score/math_llm_judge/__init__.py +++ b/verl/utils/reward_score/math_llm_judge/__init__.py @@ -398,14 +398,15 @@ def llm_check_answer(model_output: str, ground_truth: str, question: str) -> boo # use llm to check if the answer 
is correct # url = "http://176.56.200.81:30000/v1/chat/completions" - url = os.getenv("MATH_LLM_JUDGE_URL") - if not url: + url_base = os.getenv("MATH_LLM_JUDGE_URL") + if not url_base: raise ValueError("MATH_LLM_JUDGE_URL is not set") + url = url_base.rstrip("/") + "/v1/chat/completions" prompt = input_template.format(QUESTION=question, STUDENT_ANSWER=model_output, REFERENCE_ANSWER=ground_truth) data = { - "model": "Qwen/Qwen2.5-32B-Instruct", + "model": "openai/gpt-oss-120b", "messages": [{"role": "user", "content": prompt}], } response = requests.post(url, json=data) @@ -423,7 +424,7 @@ def llm_check_answer(model_output: str, ground_truth: str, question: str) -> boo def compute_score(model_output: str, ground_truth: str, extra_info: dict) -> bool: - question = extra_info["question"] + question = extra_info["original_question"] model_output = str(model_output) ground_truth = str(ground_truth) @@ -447,5 +448,4 @@ def compute_score(model_output: str, if is_matched and not is_correct: # use llm to check if the answer is correct is_correct = llm_check_answer(extracted_model_output, ground_truth, question) - - return is_correct, 1, extracted_model_output + return is_correct, 1, extracted_model_output \ No newline at end of file diff --git a/verl/utils/reward_score/naive_dapo.py b/verl/utils/reward_score/naive_dapo.py index 102d714e..f819048a 100644 --- a/verl/utils/reward_score/naive_dapo.py +++ b/verl/utils/reward_score/naive_dapo.py @@ -425,8 +425,9 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) is_correct = False - # else: - # is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + else: + # is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + is_correct = False if not is_correct: break From b40564781093279bec49169a3ce2ba6e558c4065 
Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Tue, 6 Jan 2026 19:11:08 +0000 Subject: [PATCH 12/20] Sync --- scripts/train/k2p_hero_cispo.sh | 1 + scripts/train/k2p_hero_cispo_newData.sh | 362 +++++++++++++++++++++ scripts/train/k2p_hero_grpo.sh | 6 +- scripts/train/k2p_hero_grpo_newData.sh | 58 ++-- scripts/train/k2p_hero_grpo_newNewData.sh | 371 ++++++++++++++++++++++ verl/trainer/config/data/legacy_data.yaml | 3 + verl/trainer/ppo/ray_trainer.py | 168 +++++++++- 7 files changed, 939 insertions(+), 30 deletions(-) create mode 100644 scripts/train/k2p_hero_cispo_newData.sh create mode 100644 scripts/train/k2p_hero_grpo_newNewData.sh diff --git a/scripts/train/k2p_hero_cispo.sh b/scripts/train/k2p_hero_cispo.sh index f37efe39..25e4b338 100644 --- a/scripts/train/k2p_hero_cispo.sh +++ b/scripts/train/k2p_hero_cispo.sh @@ -255,6 +255,7 @@ offload=True actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ actor_rollout_ref.actor.optim.lr=5e-7 \ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ actor_rollout_ref.actor.optim.weight_decay=0.1 \ diff --git a/scripts/train/k2p_hero_cispo_newData.sh b/scripts/train/k2p_hero_cispo_newData.sh new file mode 100644 index 00000000..8ffd6d2f --- /dev/null +++ b/scripts/train/k2p_hero_cispo_newData.sh @@ -0,0 +1,362 @@ +#!/bin/bash +#SBATCH --job-name=cispo-k2p-newFiltered-64k-mainQs-finalInstruct +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=higherprio +#SBATCH --exclude=azure-uk-hpc-H200-instance-[347-410] + +# =================== Frequently Used Variables =================== 
+RESUME_CKPT_DIR_NAME="cispo-k2p-newFiltered-64k-mainQs-finalInstruct-406746" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-286:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +TP_SOCKET_IFNAME=eth0 \ +GLOO_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== + +# Training Data Configuration +DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1" +train_file_list=() + 
+iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet" + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + "codegen__deduped_leetcode2k_2.4k.parquet" + "codegen__deduped_livecodebench_599.parquet" + "codegen__deduped_primeintellect_9.6k.parquet" + "codegen__deduped_taco_11.1k.parquet" + "ifbench__fixed_85.6k.parquet" + "logic__arcagi1_297.parquet" + "logic__arcagi2_653.parquet" + "logic__barc_3.4k.parquet" + "logic__graph_logical_dataset_1.4k.parquet" + "logic__ordering_puzzle_dataset_2.9k.parquet" + "logic__reasoning_gym_40.6k.parquet" + "logic__synlogic_12.1k.parquet" + "logic__zebra_puzzle_dataset_5.0k.parquet" + "math__combined_118.2k.part1.parquet" + "math__combined_118.2k.part2.parquet" + "omni_math_4.43k_dedup.parquet" + "simulation__codeio_fixed_12.1k.parquet" + "stem__nemotron_13.3k.parquet" + "stem__web_31.7k.parquet" + "table__hitab_7.4k.parquet" + "table__multihier_2.9k.parquet" +) + +echo "Collecting training files from ${DATA_MIX_DIR}..." 
+ +# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" +for dataset in "${dataset_names[@]}"; do + for subdir in "main_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +# for dataset in "${dataset_names[@]}"; do +# for subdir in "131k_context_questions"; do +# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" +# if [ -f "$file_path" ]; then +# echo "Adding: $file_path" +# id_val_file_list+=("'$file_path'") +# fi +# done +# done +# id_val_file_list+=("'$iq400_path'") + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +# id_val_files="[${id_val_file_list[*]}]" +unset IFS + +echo "Total training files found: ${#train_file_list[@]}" +# echo "Total ID validation files found: ${#id_val_file_list[@]}" + +# Test Data Configuration +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# 
ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (test) +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${synlogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}' + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# 
BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/cispo-k2p-newFiltered-32k-mainQs-finalInstruct-406513/global_step_100/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} +# WANDB_EXPERIMENT_NAME="cispo-k2p-newFiltered-32k-mainQs-finalInstruct-406513" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. 
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=2.0
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 64))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_mode="cispo" # Default is 'vanilla', which is equivalent to PPO
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=128 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=128 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Reduced from 32 to reduce memory pressure
+gen_tp=4
+gen_max_num_seqs=1024 # Lower this to reduce memory pressure
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+    --config-path=config \
+    --config-name="dapo_fsdp_config.yaml" \
+    algorithm.adv_estimator=${adv_estimator} \
+    algorithm.use_kl_in_reward=${use_kl_in_reward} \
+    algorithm.kl_ctrl.kl_coef=${kl_coef} \
+    algorithm.filter_groups.enable=${enable_filter_groups} \
+    algorithm.filter_groups.metric=${filter_groups_metric} \
+    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+    data.train_files="$train_files" \
+    data.val_files="$test_files" \
+    data.prompt_key=prompt \
+    data.truncation='right' \
+    data.max_prompt_length=${max_prompt_length} \
+    data.max_response_length=${max_response_length} \
+    data.train_batch_size=${train_prompt_bsz} \
+    data.gen_batch_size=${gen_prompt_bsz} \
+    actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+    actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+    actor_rollout_ref.actor.clip_ratio_c=10.0 \
+    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+    actor_rollout_ref.actor.strategy="fsdp2" \
+    actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
+    actor_rollout_ref.actor.optim.lr=5e-7 \
+    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+    actor_rollout_ref.actor.optim.weight_decay=0.1 \
+    actor_rollout_ref.actor.optim.warmup_style=constant \
+    actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 + # data.id_val_files="$id_val_files" \ \ No newline at end of file diff --git a/scripts/train/k2p_hero_grpo.sh b/scripts/train/k2p_hero_grpo.sh index e0ec0988..5908854e 100644 --- a/scripts/train/k2p_hero_grpo.sh +++ b/scripts/train/k2p_hero_grpo.sh @@ -1,5 +1,5 @@ #!/bin/bash -#SBATCH --job-name=grpo-k2p-finalInstruct-64k-temp1.2-focused +#SBATCH --job-name=grpo-k2p-32k264k-stage2-focused #SBATCH --nodes=64 #SBATCH 
--ntasks=64 #SBATCH --ntasks-per-node=1 @@ -13,10 +13,8 @@ #SBATCH --partition=main #SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394 -# SBATCH --job-name=grpo-hero-k2p-finalInstruct-temp1.2-wOmni-fix2 - # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="grpo-k2p-finalInstruct-64k-temp1.2-focused-404084" # Fill in the checkpoint directory name to resume from, otherwise from scratch +RESUME_CKPT_DIR_NAME="grpo-k2p-32k264k-stage2-focused-404083" # Fill in the checkpoint directory name to resume from, otherwise from scratch export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain # =================== Cluster Environment =================== diff --git a/scripts/train/k2p_hero_grpo_newData.sh b/scripts/train/k2p_hero_grpo_newData.sh index ade66887..7af990cd 100644 --- a/scripts/train/k2p_hero_grpo_newData.sh +++ b/scripts/train/k2p_hero_grpo_newData.sh @@ -1,5 +1,5 @@ #!/bin/bash -#SBATCH --job-name=grpo-k2p-newFiltered-64k-fullData-finalInstruct +#SBATCH --job-name=grpo-stage2-k2pRL-7domains-VaradMix #SBATCH --nodes=64 #SBATCH --ntasks=64 #SBATCH --ntasks-per-node=1 @@ -10,15 +10,13 @@ #SBATCH --error=slurm/%x-%j.log #SBATCH --exclusive #SBATCH --time=720:00:00 -#SBATCH --partition=main -#SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394 - -# SBATCH --job-name=grpo-hero-k2p-finalInstruct-temp1.2-wOmni-fix2 +#SBATCH --partition=higherprio +#SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch -export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain -export 
MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-033:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty +RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-7domains-VaradMix-415521" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-286:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty # =================== Cluster Environment =================== export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ @@ -42,7 +40,9 @@ SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ SHARP_COLL_ENABLE_SAT=1 \ SHARP_COLL_LOG_LEVEL=3 \ SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ -NCCL_COLLNET_ENABLE=1 +NCCL_COLLNET_ENABLE=1 \ +NCCL_NVLS_ENABLE=0 + # Get the list of allocated nodes nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) @@ -64,6 +64,9 @@ export VLLM_USE_V1=1 # Training Data Configuration DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1" train_file_list=() +id_val_file_list=() + +iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet" # List of datasets to include (filename only) # Comment out lines to exclude specific datasets @@ -93,9 +96,9 @@ dataset_names=( echo "Collecting training files from ${DATA_MIX_DIR}..." 
-# Search for each dataset in all subdirectories +# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions" for dataset in "${dataset_names[@]}"; do - for subdir in "impossible_questions" "131k_context_questions" "main_questions"; do + for subdir in "main_questions"; do file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" if [ -f "$file_path" ]; then echo "Adding: $file_path" @@ -104,12 +107,25 @@ for dataset in "${dataset_names[@]}"; do done done +# for dataset in "${dataset_names[@]}"; do +# for subdir in "131k_context_questions"; do +# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" +# if [ -f "$file_path" ]; then +# echo "Adding: $file_path" +# id_val_file_list+=("'$file_path'") +# fi +# done +# done +# id_val_file_list+=("'$iq400_path'") + # Join with comma to form Python list string IFS=, train_files="[${train_file_list[*]}]" +# id_val_files="[${id_val_file_list[*]}]" unset IFS echo "Total training files found: ${#train_file_list[@]}" +# echo "Total ID validation files found: ${#id_val_file_list[@]}" # Test Data Configuration TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len @@ -150,18 +166,19 @@ if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet # test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" # Full data mixture (uncomment to use) -test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${synlogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # 
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${synlogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}' # =================== Model =================== # BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT -BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) # BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k -# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906/global_step_300/actor/huggingface +BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface # =================== Logging =================== WANDB_PROJECT=k2plus_rl WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} +# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491" # If RESUME_CKPT_DIR is not empty, resume from the checkpoint if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then @@ -222,10 +239,10 @@ rollout_dtype="float16" enable_filter_groups=False filter_groups_metric=acc max_num_gen_batches=10 -train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +train_prompt_bsz=128 # on-policy model update batchsize: train_prompt_bsz * rollout.n 
gen_prompt_bsz=$((train_prompt_bsz * 1)) n_resp_per_prompt=16 -train_prompt_mini_bsz=256 # model grad update batchsize +train_prompt_mini_bsz=128 # model grad update batchsize # Algorithm temperature=1.2 @@ -298,7 +315,7 @@ offload=True actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ actor_rollout_ref.rollout.enable_chunked_prefill=True \ @@ -323,7 +340,7 @@ offload=True +actor_rollout_ref.model.override_config.embd_pdrop=0. \ +actor_rollout_ref.model.override_config.resid_pdrop=0. \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.enable_activation_offload=${offload} \ actor_rollout_ref.model.use_liger=True \ reward_model.reward_manager=async_multi_process \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ @@ -332,7 +349,7 @@ offload=True trainer.logger=['console','wandb'] \ trainer.project_name=${WANDB_PROJECT} \ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ - trainer.val_before_train=True \ + trainer.val_before_train=False \ trainer.n_gpus_per_node=8 \ trainer.nnodes=$worker_num \ trainer.save_freq=10 \ @@ -340,4 +357,5 @@ offload=True trainer.total_epochs=5 \ trainer.log_val_generations=50 \ trainer.resume_mode=auto \ - trainer.max_actor_ckpt_to_keep=3 \ No newline at end of file + trainer.max_actor_ckpt_to_keep=3 + # data.id_val_files="$id_val_files" \ \ No newline at end of file diff --git a/scripts/train/k2p_hero_grpo_newNewData.sh b/scripts/train/k2p_hero_grpo_newNewData.sh new file mode 100644 index 00000000..a487eaa9 --- 
/dev/null +++ b/scripts/train/k2p_hero_grpo_newNewData.sh @@ -0,0 +1,371 @@ +#!/bin/bash +#SBATCH --job-name=grpo-stage2-k2pRL-easy50k-7domains +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=higherprio +#SBATCH --exclude=azure-uk-hpc-H200-instance-[028-049] + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-025:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export 
head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== + +# Training Data Configuration +# DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1" +DATA_MIX_DIR="/lustrefs/users/haonan.li/Reasoning360/final/data-mixtures/easy/data_mix_seven_domains_20000" +train_file_list=() +id_val_file_list=() + +iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet" + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + "codegen__deduped_leetcode2k_2.4k_scored.parquet" + "codegen__deduped_livecodebench_599_scored.parquet" + "codegen__deduped_primeintellect_9.6k_scored.parquet" + "codegen__deduped_taco_11.1k_scored.parquet" + "ifbench__fixed_85.6k_scored.parquet" + "logic__arcagi1_297_scored.parquet" + "logic__arcagi2_653_scored.parquet" + "logic__barc_3.4k_scored.parquet" + "logic__graph_logical_dataset_1.4k_scored.parquet" + "logic__ordering_puzzle_dataset_2.9k_scored.parquet" + "logic__reasoning_gym_40.6k_scored.parquet" + "logic__synlogic_12.1k_scored.parquet" + "logic__zebra_puzzle_dataset_5.0k_scored.parquet" + "math__combined_118.2k.part1_scored.parquet" + "math__combined_118.2k.part2_scored.parquet" + "omni_math_4.43k_scored.parquet" + "simulation__codeio_fixed_12.1k_scored.parquet" + "stem__nemotron_13.3k_scored.parquet" + "stem__web_31.7k_scored.parquet" + "table__hitab_7.4k_scored.parquet" + "table__multihier_2.9k_scored.parquet" +) +# "omni_math_4.43k_dedup.parquet" + +echo "Collecting training files from ${DATA_MIX_DIR}/..." 
+ +for dataset in "${dataset_names[@]}"; do + for subdir in "easy_questions" "main_questions" "131k_context_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + + if [ -f "$file_path" ]; then + echo "Found: $file_path" + train_file_list+=("'$file_path'") + + # --- NEW ELIF BLOCK --- + # Checks for the file with an extra "_scored" inserted before .parquet + # e.g., changes "file_scored.parquet" to "file_scored_scored.parquet" + elif [ -f "${DATA_MIX_DIR}/${subdir}/${dataset/.parquet/_scored.parquet}" ]; then + alt_file_path="${DATA_MIX_DIR}/${subdir}/${dataset/.parquet/_scored.parquet}" + echo "Found (with extra _scored): $alt_file_path" + train_file_list+=("'$alt_file_path'") + else + echo "Missing: $file_path" + fi + done +done + +# for dataset in "${dataset_names[@]}"; do +# for subdir in "131k_context_questions"; do +# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" +# if [ -f "$file_path" ]; then +# echo "Adding: $file_path" +# id_val_file_list+=("'$file_path'") +# fi +# done +# done +# id_val_file_list+=("'$iq400_path'") + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +# id_val_files="[${id_val_file_list[*]}]" +unset IFS + +echo "Total training files found: ${#train_file_list[@]}" +# echo "Total ID validation files found: ${#id_val_file_list[@]}" + +# Test Data Configuration +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (test) 
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (test) +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${synlogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}' + + +# 
=================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} +# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 
--block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 64)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.2 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Reduced from 32 to reduce memory pressure +gen_tp=4 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + 
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 + # data.id_val_files="$id_val_files" \ \ No newline at end of file diff --git a/verl/trainer/config/data/legacy_data.yaml b/verl/trainer/config/data/legacy_data.yaml index 028405b4..dde4c197 100644 --- a/verl/trainer/config/data/legacy_data.yaml +++ b/verl/trainer/config/data/legacy_data.yaml @@ -10,6 +10,9 @@ use_shm: False # For HDFS path, we provide utils to download it to DRAM and convert it to a local path. 
train_files: ~/data/rlhf/gsm8k/train.parquet +# ID Validation set parquet. Can be a list or a single file. +# id_val_files: null + # Validation parquet. Can be a list or a single file. val_files: ~/data/rlhf/gsm8k/test.parquet diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index e8d3e6a2..ae31a481 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -280,6 +280,7 @@ def __init__( reward_fn=None, val_reward_fn=None, train_dataset: Optional[Dataset] = None, + # id_val_dataset: Optional[Dataset] = None, val_dataset: Optional[Dataset] = None, collate_fn=None, train_sampler: Optional[Sampler] = None, @@ -299,6 +300,7 @@ def __init__( reward_fn: Function for computing rewards during training. val_reward_fn: Function for computing rewards during validation. train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + # id_val_dataset (Optional[Dataset], optional): id Validation dataset. Defaults to None. val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. collate_fn: Function to collate data samples into batches. train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. @@ -338,11 +340,12 @@ def __init__( if self.config.algorithm.use_kl_in_reward: self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) - self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) # id_val_dataset - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): # id_val_dataset """ Creates the train and validation dataloaders. + # Added by Reasoning 360: creates a third dataloader for id validation... To be kept separate from standard validation approach. 
""" # TODO: we have to make sure the batch size is divisible by the dp size from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler @@ -351,11 +354,16 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl train_dataset = create_rl_dataset( self.config.data.train_files, self.config.data, self.tokenizer, self.processor ) + # if id_val_dataset is None and self.config.data.id_val_files is not None: + # id_val_dataset = create_rl_dataset( + # self.config.data.id_val_files, self.config.data, self.tokenizer, self.processor + # ) if val_dataset is None: val_dataset = create_rl_dataset( self.config.data.val_files, self.config.data, self.tokenizer, self.processor ) - self.train_dataset, self.val_dataset = train_dataset, val_dataset + self.train_dataset, self.val_dataset = train_dataset, val_dataset + # self.id_val_dataset = id_val_dataset if train_sampler is None: train_sampler = create_rl_sampler(self.config.data, self.train_dataset) @@ -376,9 +384,23 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl ) val_batch_size = self.config.data.val_batch_size # Prefer config value if set + # id_val_batch_size = self.config.data.val_batch_size if val_batch_size is None: val_batch_size = len(self.val_dataset) - + # id_val_batch_size = len(self.id_val_dataset) if self.id_val_dataset is not None else 0 + + # if self.id_val_dataset is not None: + # self.id_val_dataloader = StatefulDataLoader( + # dataset=self.id_val_dataset, + # batch_size=id_val_batch_size, + # num_workers=num_workers, + # shuffle=self.config.data.get("id_validation_shuffle", True), + # drop_last=False, + # collate_fn=collate_fn, + # ) + # else: + # self.id_val_dataloader = None + self.val_dataloader = StatefulDataLoader( dataset=self.val_dataset, batch_size=val_batch_size, @@ -389,12 +411,13 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl ) assert len(self.train_dataloader) >= 1, "Train dataloader is 
empty!" + # assert self.id_val_dataloader is None or len(self.id_val_dataloader) >= 1, "id Validation dataloader is empty!" assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" print( f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " f"{len(self.val_dataloader)}" - ) + ) # Size of id val dataloader: {len(self.id_val_dataloader)} total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs @@ -485,9 +508,12 @@ def _get_gen_batch(self, batch: DataProto) -> DataProto: def _validate(self): data_source_lst = [] + # id_data_source_lst = [] reward_extra_infos_dict: dict[str, list] = defaultdict(list) + # id_reward_extra_infos_dict: dict[str, list] = defaultdict(list) # NOTE: added by Reasoning360. dataset_lst = [] + # dataset_id_lst = [] # Lists to collect samples for the table sample_inputs = [] @@ -497,6 +523,112 @@ def _validate(self): sample_turns = [] sample_uids = [] + # sample_id_inputs = [] + # sample_id_outputs = [] + # sample_id_gts = [] + # sample_id_scores = [] + # sample_id_turns = [] + # sample_id_uids = [] + + # if self.id_val_dataloader is not None: + # print("Starting id validation generation...") + # for test_data in self.id_val_dataloader: + # test_batch = DataProto.from_single_dict(test_data) + + # if "uid" not in test_batch.non_tensor_batch: + # test_batch.non_tensor_batch["uid"] = np.array( + # [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object + # ) + + # # repeat test batch + # test_batch = test_batch.repeat( + # repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + # ) + + # # we only do validation on rule-based rm + # if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + # return {} + + # # Store original inputs + # input_ids = test_batch.batch["input_ids"] + # # TODO: Can we keep special tokens except for padding tokens? 
+ # input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + # sample_id_inputs.extend(input_texts) + # sample_id_uids.extend(test_batch.non_tensor_batch["uid"]) + + # ground_truths = [ + # item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch + # ] + # sample_id_gts.extend(ground_truths) + + # test_gen_batch = self._get_gen_batch(test_batch) + # test_gen_batch.meta_info = { + # "eos_token_id": self.tokenizer.eos_token_id, + # "pad_token_id": self.tokenizer.pad_token_id, + # "recompute_log_prob": False, + # "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + # "validate": True, + # "global_steps": self.global_steps, + # } + # print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # # pad to be divisible by dp_size + # size_divisor = ( + # self.actor_rollout_wg.world_size + # if not self.async_rollout_mode + # else self.config.actor_rollout_ref.rollout.agent.num_workers + # ) + # test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + # if not self.async_rollout_mode: + # test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + # else: + # test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + # # unpad + # test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + + # print("ID Validation generation end") + + # # Store generated outputs + # output_ids = test_output_gen_batch.batch["responses"] + # output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + # sample_id_outputs.extend(output_texts) + + # test_batch = test_batch.union(test_output_gen_batch) + # test_batch.meta_info["validate"] = True + + # # evaluate using reward_function + # if self.val_reward_fn is None: + # raise ValueError("val_reward_fn must be provided for validation.") + # result = 
self.val_reward_fn(test_batch, return_dict=True) + # reward_tensor = result["reward_tensor"] + # scores = reward_tensor.sum(-1).cpu().tolist() + # sample_id_scores.extend(scores) + + # id_reward_extra_infos_dict["reward"].extend(scores) + # print(f"len id_reward_extra_infos_dict['reward']: {len(id_reward_extra_infos_dict['reward'])}") + # if "reward_extra_info" in result: + # for key, lst in result["reward_extra_info"].items(): + # id_reward_extra_infos_dict[key].extend(lst) + # print(f"len id_reward_extra_infos_dict['{key}']: {len(id_reward_extra_infos_dict[key])}") + + # # NOTE: added by Reasoning360. Collect dataset information. TODO: maybe replicated usage with the data_source_lst and can be removed? + # datasets = [] + # for i in range(reward_tensor.shape[0]): + # dataset = "unknown" + # if "extra_info" in test_batch.non_tensor_batch: + # extra_info = test_batch.non_tensor_batch["extra_info"][i] + # if isinstance(extra_info, dict) and "dataset" in extra_info: + # dataset = extra_info["dataset"] + # datasets.append(dataset) + # dataset_id_lst.append(np.array(datasets)) + + # # collect num_turns of each prompt + # if "__num_turns__" in test_batch.non_tensor_batch: + # sample_id_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + # id_data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + for test_data in self.val_dataloader: test_batch = DataProto.from_single_dict(test_data) @@ -608,13 +740,37 @@ def _validate(self): dump_path=val_data_dir, ) + # for key_info, lst in id_reward_extra_infos_dict.items(): + # assert len(lst) == 0 or len(lst) == len(sample_id_scores), f"{key_info}: {len(lst)=}, {len(sample_id_scores)=}" + for key_info, lst in reward_extra_infos_dict.items(): assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" # NOTE: Added by Reasoning360: Calculate the mean reward for each data source and dataset + # id_data_sources = 
np.concatenate(id_data_source_lst, axis=0) + # id_datasets = np.concatenate(dataset_id_lst, axis=0) # Concatenate datasets data_sources = np.concatenate(data_source_lst, axis=0) datasets = np.concatenate(dataset_lst, axis=0) # Concatenate datasets + # id_data_src2var2metric2val = process_validation_metrics(id_data_sources, sample_id_uids, id_reward_extra_infos_dict) + # id_metric_dict = {} + # for data_source, var2metric2val in id_data_src2var2metric2val.items(): + # core_var = "acc" if "acc" in var2metric2val else "reward" + # for var_name, metric2val in var2metric2val.items(): + # n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + # for metric_name, metric_val in metric2val.items(): + # # NOTE: added by Reasoning360. Add std metrics + # if ( + # (var_name == core_var) + # and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best", "std"]) + # and (f"@{n_max}" in metric_name) + # ): + # metric_sec = "InDomain-Eval-core" + # else: + # metric_sec = "InDomain-Eval-aux" + # pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + # id_metric_dict[pfx] = metric_val + data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) metric_dict = {} for data_source, var2metric2val in data_src2var2metric2val.items(): @@ -654,7 +810,7 @@ def _validate(self): for (data_source, dataset), rewards in data_source_dataset_reward.items(): metric_dict[f"val/test_score/{data_source}/{dataset}"] = np.mean(rewards) - return metric_dict + return metric_dict # id_metric_dict | metric_dict # Union of two dicts def init_workers(self): """Initialize distributed training workers using Ray backend. 
From 235af1a26fdd48b163a9a899d0a917b3a5aeccda Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Tue, 6 Jan 2026 21:43:43 +0000 Subject: [PATCH 13/20] Updated reward functions from the async branch --- verl/utils/reward_score/cruxeval/utils.py | 31 ++-- verl/utils/reward_score/graph_dataset.py | 71 ++++----- .../reward_score/math_llm_judge/__init__.py | 44 ++---- .../reward_score/math_llm_judge/grader.py | 40 ++--- verl/utils/reward_score/naive_dapo.py | 147 +++++++----------- .../utils/reward_score/prime_math/__init__.py | 2 +- verl/utils/reward_score/puzzles_dataset.py | 66 ++++---- verl/utils/reward_score/zebra_puzzle.py | 45 ++---- 8 files changed, 170 insertions(+), 276 deletions(-) diff --git a/verl/utils/reward_score/cruxeval/utils.py b/verl/utils/reward_score/cruxeval/utils.py index 9eec77cd..f554152c 100644 --- a/verl/utils/reward_score/cruxeval/utils.py +++ b/verl/utils/reward_score/cruxeval/utils.py @@ -9,9 +9,10 @@ import multiprocessing import os import platform -import signal import tempfile +from verl.utils.py_functional import timeout_limit + def check_correctness(check_program, timeout=3): """ @@ -53,12 +54,15 @@ def unsafe_execute(check_program, result, timeout): # Run program. 
try: - exec_globals = {} - with swallow_io(): - with time_limit(timeout): + @timeout_limit(seconds=timeout) + def _exec_with_timeout(): + exec_globals = {} + with swallow_io(): exec(check_program, exec_globals) + + _exec_with_timeout() result.append("passed") - except TimeoutException: + except TimeoutError: result.append("timed out") except BaseException as e: result.append(f"failed: {e}") @@ -69,19 +73,6 @@ def unsafe_execute(check_program, result, timeout): os.chdir = chdir -@contextlib.contextmanager -def time_limit(seconds): - def signal_handler(signum, frame): - raise TimeoutException("Timed out!") - - signal.setitimer(signal.ITIMER_REAL, seconds) - signal.signal(signal.SIGALRM, signal_handler) - try: - yield - finally: - signal.setitimer(signal.ITIMER_REAL, 0) - - @contextlib.contextmanager def swallow_io(): stream = WriteOnlyStringIO() @@ -98,10 +89,6 @@ def create_tempdir(): yield dirname -class TimeoutException(Exception): - pass - - class WriteOnlyStringIO(io.StringIO): """StringIO that throws an exception when it's read from""" diff --git a/verl/utils/reward_score/graph_dataset.py b/verl/utils/reward_score/graph_dataset.py index 3bff2554..831b4821 100644 --- a/verl/utils/reward_score/graph_dataset.py +++ b/verl/utils/reward_score/graph_dataset.py @@ -2,70 +2,63 @@ import random import ast import operator -import signal -import contextlib -class TimeoutException(Exception): - pass +from verl.utils.py_functional import timeout_limit -@contextlib.contextmanager -def time_limit(seconds: float): - def signal_handler(signum, frame): - raise TimeoutException("Timed out!") - - signal.setitimer(signal.ITIMER_REAL, seconds) - signal.signal(signal.SIGALRM, signal_handler) - try: - yield - finally: - signal.setitimer(signal.ITIMER_REAL, 0) def extract_solution(solution_str): - answer_pattern = r'(.*?)' + answer_pattern = r"(.*?)" match = re.finditer(answer_pattern, solution_str) matches = list(match) if matches: final_answer = matches[-1].group(1).strip() - if 
re.search(r'^[A-Za-z]+$', final_answer): + if re.search(r"^[A-Za-z]+$", final_answer): return final_answer else: return None -def compute_score(solution_str, ground_truth, extra_info: any = None, timeout: float = 10.0): + +def compute_score( + solution_str, ground_truth, extra_info: any = None, timeout: float = 10.0 +): """The scoring function for graph dataset task. - + Args: solution_str: the solution text ground_truth: the correct answer timeout: maximum time in seconds to allow for computation """ - try: - with time_limit(timeout): - if not isinstance(ground_truth, str): - ground_truth = str(ground_truth) - target = ground_truth.lower() - solution = extract_solution(solution_str) - - if solution: - solution = solution.lower() - else: - score = 0.0 + @timeout_limit(seconds=timeout) + def _compute_with_timeout(): + if not isinstance(ground_truth, str): + ground_truth_str = str(ground_truth) + else: + ground_truth_str = ground_truth + + target = ground_truth_str.lower() + solution = extract_solution(solution_str) - try: - if target == solution: - score = 1.0 - else: - score = 0.0 + if solution: + solution = solution.lower() + else: + return 0.0 - except Exception as e: - score = 0.0 + try: + if target == solution: + return 1.0 + else: + return 0.0 + except Exception as e: + return 0.0 - except TimeoutException: + try: + score = _compute_with_timeout() + except TimeoutError: print("Computation timed out in graph_dataset") - score = 0.0 + score = 0.0 except Exception as e: print(f"Error in compute_score in graph_dataset: {e}") score = 0.0 diff --git a/verl/utils/reward_score/math_llm_judge/__init__.py b/verl/utils/reward_score/math_llm_judge/__init__.py index 3d78e1c6..e7b9224d 100644 --- a/verl/utils/reward_score/math_llm_judge/__init__.py +++ b/verl/utils/reward_score/math_llm_judge/__init__.py @@ -36,15 +36,17 @@ """ import re +import math + import sympy from pylatexenc import latex2text from sympy.parsing import sympy_parser -import os +import requests +from 
verl.utils.py_functional import timeout_limit + from . import math_normalize from .grader import math_equal -import requests - # import math_normalize # from grader import math_equal @@ -54,31 +56,6 @@ TUPLE_CHARS = "()[]" -def timeout(timeout_seconds: int = 8): - if os.name == "posix": - import signal - - def decorator(func): - - def handler(signum, frame): - raise TimeoutError("Operation timed out!") - - def wrapper(*args, **kwargs): - old_handler = signal.getsignal(signal.SIGALRM) - signal.signal(signal.SIGALRM, handler) - signal.alarm(timeout_seconds) - - try: - return func(*args, **kwargs) - finally: - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) - - return wrapper - - return decorator - else: - raise NotImplementedError(f"Unsupported OS: {os.name}") def _sympy_parse(expr: str): @@ -255,7 +232,7 @@ def should_allow_eval(expr: str): return True -@timeout(timeout_seconds=10) +@timeout_limit(seconds=10) def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): are_equal = False try: @@ -332,7 +309,10 @@ def grade_answer(given_answer: str, ground_truth: str) -> bool: # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) is_correct = False else: - is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + try: + is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + except TimeoutError: + is_correct = False if not is_correct: break @@ -392,8 +372,6 @@ def match_answer(response): return is_matched, response -import math - def llm_check_answer(model_output: str, ground_truth: str, question: str) -> bool: # use llm to check if the answer is correct @@ -448,4 +426,4 @@ def compute_score(model_output: str, if is_matched and not is_correct: # use llm to check if the answer is correct is_correct = llm_check_answer(extracted_model_output, ground_truth, question) - return is_correct, 1, extracted_model_output \ No newline at end of file + return is_correct, 
1, extracted_model_output diff --git a/verl/utils/reward_score/math_llm_judge/grader.py b/verl/utils/reward_score/math_llm_judge/grader.py index 34f0c7f5..87d8bac6 100644 --- a/verl/utils/reward_score/math_llm_judge/grader.py +++ b/verl/utils/reward_score/math_llm_judge/grader.py @@ -94,7 +94,6 @@ import contextlib import re -import signal import math from math import isclose from typing import Union @@ -102,6 +101,7 @@ from sympy import N, simplify from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr +from verl.utils.py_functional import timeout_limit def is_digit(s): @@ -312,8 +312,10 @@ def symbolic_equal(a, b, tolerance, timeout=10.0): def _parse(s): for f in [parse_expr, parse_latex]: try: - with time_limit(timeout): + @timeout_limit(seconds=timeout) + def _parse_with_timeout(): return f(s) + return _parse_with_timeout() except Exception: pass return s @@ -322,39 +324,25 @@ def _parse(s): b = _parse(b) try: - with time_limit(timeout): - if simplify(a - b) == 0: - return True + @timeout_limit(seconds=timeout) + def _simplify_with_timeout(): + return simplify(a - b) == 0 + if _simplify_with_timeout(): + return True except Exception: pass try: - with time_limit(timeout): - if isclose(N(a), N(b), rel_tol=tolerance): - return True + @timeout_limit(seconds=timeout) + def _numeric_equal_with_timeout(): + return isclose(N(a), N(b), rel_tol=tolerance) + if _numeric_equal_with_timeout(): + return True except Exception: pass return False -class TimeoutException(Exception): - pass - - -@contextlib.contextmanager -def time_limit(seconds: float): - - def signal_handler(signum, frame): - raise TimeoutException("Timed out!") - - signal.setitimer(signal.ITIMER_REAL, seconds) - signal.signal(signal.SIGALRM, signal_handler) - try: - yield - finally: - signal.setitimer(signal.ITIMER_REAL, 0) - - def format_intervals(prediction): patterns = { "Interval(": r"^Interval\((.*)\)$", diff --git a/verl/utils/reward_score/naive_dapo.py 
b/verl/utils/reward_score/naive_dapo.py index f819048a..8ad39d01 100644 --- a/verl/utils/reward_score/naive_dapo.py +++ b/verl/utils/reward_score/naive_dapo.py @@ -14,34 +14,18 @@ # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py import re -import signal from typing import Optional +import math import sympy from pylatexenc import latex2text from sympy.parsing import sympy_parser -import os +from verl.utils.py_functional import timeout_limit + from .prime_math import math_normalize from .prime_math.grader import math_equal -class timeout: - - def __init__(self, seconds=1, error_message="Timeout"): - self.seconds = seconds - self.error_message = error_message - - def handle_timeout(self, signum, frame): - raise TimeoutError(self.error_message) - - def __enter__(self): - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - - def __exit__(self, type, value, traceback): - signal.alarm(0) - - # Constants for normalization SUBSTITUTIONS = [ ("an ", ""), @@ -103,10 +87,10 @@ def __exit__(self, type, value, traceback): def normalize_final_answer(final_answer: str) -> str: """Normalize a final answer to a quantitative reasoning question. 
- + Args: final_answer: The answer string to normalize - + Returns: Normalized answer string """ @@ -148,42 +132,19 @@ def normalize_final_answer(final_answer: str) -> str: TUPLE_CHARS = "()[]" -def timeout(timeout_seconds: int = 8): - if os.name == "posix": - import signal - - def decorator(func): - - def handler(signum, frame): - raise TimeoutError("Operation timed out!") - - def wrapper(*args, **kwargs): - old_handler = signal.getsignal(signal.SIGALRM) - signal.signal(signal.SIGALRM, handler) - signal.alarm(timeout_seconds) - - try: - return func(*args, **kwargs) - finally: - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) - - return wrapper - - return decorator - else: - raise NotImplementedError(f"Unsupported OS: {os.name}") - - def _sympy_parse(expr: str): """Parses an expression with sympy.""" py_expr = expr.replace("^", "**") return sympy_parser.parse_expr( py_expr, - transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)), + transformations=( + sympy_parser.standard_transformations + + (sympy_parser.implicit_multiplication_application,) + ), ) +# @timeout(timeout_seconds=5) def _parse_latex(expr: str) -> str: """Attempts to parse latex to an expression sympy can read.""" expr = expr.replace("\\tfrac", "\\frac") @@ -279,23 +240,23 @@ def _normalize(expr: str) -> str: expr = expr.replace("trillion", "*10^12") for unit in [ - "degree", - "cm", - "centimeter", - "meter", - "mile", - "second", - "minute", - "hour", - "day", - "week", - "month", - "year", - "foot", - "feet", - "inch", - "yard", - "liter", + "degree", + "cm", + "centimeter", + "meter", + "mile", + "second", + "minute", + "hour", + "day", + "week", + "month", + "year", + "foot", + "feet", + "inch", + "yard", + "liter", ]: expr = re.sub(f"{unit}(es)?(s)? 
*(\^[0-9]+)?", "", expr) expr = re.sub(f"\^ *\\\\circ", "", expr) @@ -349,7 +310,7 @@ def should_allow_eval(expr: str): return True -@timeout(timeout_seconds=10) +@timeout_limit(seconds=10) def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): are_equal = False try: @@ -359,7 +320,7 @@ def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): simplified = sympy.simplify(sympy_diff) if simplified == 0: are_equal = True - except: + except Exception: pass return are_equal @@ -371,8 +332,12 @@ def split_tuple(expr: str): expr = _strip_properly_formatted_commas(expr) if len(expr) == 0: return [] - if (len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and - all([ch not in expr[1:-1] for ch in TUPLE_CHARS])): + if ( + len(expr) > 2 + and expr[0] in TUPLE_CHARS + and expr[-1] in TUPLE_CHARS + and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) + ): elems = [elem.strip() for elem in expr[1:-1].split(",")] else: elems = [expr] @@ -411,8 +376,10 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: ground_truth_elems = split_tuple(ground_truth_normalized) given_elems = split_tuple(given_normalized) - if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or - ground_truth_normalized[-1] != given_normalized[-1]): + if len(ground_truth_elems) > 1 and ( + ground_truth_normalized[0] != given_normalized[0] + or ground_truth_normalized[-1] != given_normalized[-1] + ): is_correct = False elif len(ground_truth_elems) != len(given_elems): is_correct = False @@ -426,13 +393,16 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) is_correct = False else: - # is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) - is_correct = False + try: + is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + except 
TimeoutError: + is_correct = False if not is_correct: break return is_correct, given_normalized + def _last_boxed_only_string(string): idx = string.rfind("\\boxed") if idx < 0: @@ -460,7 +430,7 @@ def _last_boxed_only_string(string): if left_brace_idx is None or right_brace_idx is None: return None - return string[left_brace_idx + 1:right_brace_idx].strip() + return string[left_brace_idx + 1 : right_brace_idx].strip() def match_answer(response): @@ -472,21 +442,21 @@ def match_answer(response): if ans_boxed: is_matched = True response = ans_boxed - + return is_matched, response + import math -def compute_score(solution_str: str, - ground_truth: str, - extra_info: dict) -> float: + +def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> float: """Compute the reward score for a solution. This draws heavily from the LLM-as-judge and PRIME reward functions - + Args: solution_str: The solution string ground_truth: The ground truth answer extra_info: dict with additional info for the score computation - + Returns: Reward score (1.0 for correct, -1.0 for incorrect) """ @@ -496,30 +466,33 @@ def compute_score(solution_str: str, # Extract answer from generated output is_matched, extracted_model_output = match_answer(model_output) - + # TWK NOTE: WE REMOVED THE RESPONSE TRUNCATION FROM math_dapo.compute_score # Verify the solution, first check simple comparisons. correct, pred = grade_answer(extracted_model_output, ground_truth) - if not correct: + if not correct: try: if "\\pi" in extracted_model_output or "\\pi" in ground_truth: equivs = [] for pi in [math.pi, 3.14]: - equivs.append(math_equal(extracted_model_output, ground_truth, tiemout=True, pi=pi)) + equivs.append( + math_equal( + extracted_model_output, ground_truth, timeout=True, pi=pi + ) + ) correct = any(equivs) else: correct = math_equal(extracted_model_output, ground_truth, timeout=True) except: correct = False - # reward = 1.0 if correct else -1.0 - reward = 1.0 if correct else 0. 
+ reward = 1.0 if correct else 0.0 acc = correct return { "score": reward, "acc": acc, - } \ No newline at end of file + } diff --git a/verl/utils/reward_score/prime_math/__init__.py b/verl/utils/reward_score/prime_math/__init__.py index d0ea47ac..8d9d273e 100644 --- a/verl/utils/reward_score/prime_math/__init__.py +++ b/verl/utils/reward_score/prime_math/__init__.py @@ -55,7 +55,7 @@ def _parse_latex(expr: str) -> str: expr = expr.replace("\\tfrac", "\\frac") expr = expr.replace("\\dfrac", "\\frac") expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. - # expr = latex2text.LatexNodes2Text().latex_to_text(expr) + expr = latex2text.LatexNodes2Text().latex_to_text(expr) # Replace the specific characters that this parser uses. expr = expr.replace("√", "sqrt") diff --git a/verl/utils/reward_score/puzzles_dataset.py b/verl/utils/reward_score/puzzles_dataset.py index 72433430..23401e77 100644 --- a/verl/utils/reward_score/puzzles_dataset.py +++ b/verl/utils/reward_score/puzzles_dataset.py @@ -2,23 +2,8 @@ import random import ast import operator -import signal -import contextlib -class TimeoutException(Exception): - pass - -@contextlib.contextmanager -def time_limit(seconds: float): - def signal_handler(signum, frame): - raise TimeoutException("Timed out!") - - signal.setitimer(signal.ITIMER_REAL, seconds) - signal.signal(signal.SIGALRM, signal_handler) - try: - yield - finally: - signal.setitimer(signal.ITIMER_REAL, 0) +from verl.utils.py_functional import timeout_limit def extract_solution(solution_str): @@ -92,30 +77,33 @@ def compute_score(solution_str, ground_truth, extra_info: any = None, method='st method: the method to extract the solution timeout: maximum time in seconds to allow for computation """ - try: - with time_limit(timeout): - target = ground_truth.tolist() if not isinstance(ground_truth,list) else ground_truth - predicted_arrangement = extract_solution(solution_str=solution_str) - - if predicted_arrangement is None: - score = 0.0 - 
- # Evaluate equation - try: - if isinstance(predicted_arrangement, list) and isinstance(target, list): - edit_distance = compute_edit_distance(predicted_arrangement, target) - max_possible_dist = max(len(predicted_arrangement), len(target)) - result = predicted_arrangement == target - if result: - score = 1.0 - elif method != 'strict': - score = max(1.0 - (edit_distance / max_possible_dist)) - else: - score = 0.0 - except Exception as e: - score = 0.0 + @timeout_limit(seconds=timeout) + def _compute_with_timeout(): + target = ground_truth.tolist() if not isinstance(ground_truth,list) else ground_truth + predicted_arrangement = extract_solution(solution_str=solution_str) + + if predicted_arrangement is None: + return 0.0 - except TimeoutException: + # Evaluate equation + try: + if isinstance(predicted_arrangement, list) and isinstance(target, list): + edit_distance = compute_edit_distance(predicted_arrangement, target) + max_possible_dist = max(len(predicted_arrangement), len(target)) + result = predicted_arrangement == target + if result: + return 1.0 + elif method != 'strict': + return max(1.0 - (edit_distance / max_possible_dist)) + else: + return 0.0 + except Exception as e: + return 0.0 + + score = 0.0 + try: + score = _compute_with_timeout() + except TimeoutError: print("Computation timed out in puzzles_dataset") score = 0.0 except Exception as e: diff --git a/verl/utils/reward_score/zebra_puzzle.py b/verl/utils/reward_score/zebra_puzzle.py index 01da4d67..90ac730d 100644 --- a/verl/utils/reward_score/zebra_puzzle.py +++ b/verl/utils/reward_score/zebra_puzzle.py @@ -3,23 +3,8 @@ import ast import operator import json -import signal -import contextlib -class TimeoutException(Exception): - pass - -@contextlib.contextmanager -def time_limit(seconds: float): - def signal_handler(signum, frame): - raise TimeoutException("Timed out!") - - signal.setitimer(signal.ITIMER_REAL, seconds) - signal.signal(signal.SIGALRM, signal_handler) - try: - yield - finally: - 
signal.setitimer(signal.ITIMER_REAL, 0) +from verl.utils.py_functional import timeout_limit def extract_solution(solution_str): @@ -68,20 +53,22 @@ def compute_accuracy(answer, ground_truth): return accuracy def compute_score(solution_str, ground_truth, extra_info: any = None, method='strict', timeout: float = 10.0): - try: - with time_limit(timeout): - predicted_arrangement = extract_solution(solution_str) + @timeout_limit(seconds=timeout) + def _compute_with_timeout(): + predicted_arrangement = extract_solution(solution_str) - if predicted_arrangement is None: - score = 0.0 - else: - try: - accuracy = compute_accuracy(predicted_arrangement, ground_truth) - score = accuracy - except Exception as e: - score = 0.0 - - except TimeoutException: + if predicted_arrangement is None: + return 0.0 + else: + try: + accuracy = compute_accuracy(predicted_arrangement, ground_truth) + return accuracy + except Exception as e: + return 0.0 + + try: + score = _compute_with_timeout() + except TimeoutError: print("Computation timed out in zebra_puzzle") score = 0.0 except Exception as e: From 68c4492b5d014b945241d781a131ba1c3a8f4fea Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Thu, 8 Jan 2026 03:04:56 +0000 Subject: [PATCH 14/20] sync --- scripts/train/k2p_hero_grpo_newData.sh | 33 +++++++++++++------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/scripts/train/k2p_hero_grpo_newData.sh b/scripts/train/k2p_hero_grpo_newData.sh index 7af990cd..5b6d599b 100644 --- a/scripts/train/k2p_hero_grpo_newData.sh +++ b/scripts/train/k2p_hero_grpo_newData.sh @@ -76,23 +76,23 @@ dataset_names=( "codegen__deduped_primeintellect_9.6k.parquet" "codegen__deduped_taco_11.1k.parquet" "ifbench__fixed_85.6k.parquet" - "logic__arcagi1_297.parquet" - "logic__arcagi2_653.parquet" - "logic__barc_3.4k.parquet" - "logic__graph_logical_dataset_1.4k.parquet" - "logic__ordering_puzzle_dataset_2.9k.parquet" - "logic__reasoning_gym_40.6k.parquet" - "logic__synlogic_12.1k.parquet" - 
"logic__zebra_puzzle_dataset_5.0k.parquet" "math__combined_118.2k.part1.parquet" "math__combined_118.2k.part2.parquet" "omni_math_4.43k_dedup.parquet" - "simulation__codeio_fixed_12.1k.parquet" "stem__nemotron_13.3k.parquet" "stem__web_31.7k.parquet" "table__hitab_7.4k.parquet" "table__multihier_2.9k.parquet" ) +# # "simulation__codeio_fixed_12.1k.parquet" +# "logic__arcagi1_297.parquet" +# "logic__arcagi2_653.parquet" +# "logic__barc_3.4k.parquet" +# "logic__graph_logical_dataset_1.4k.parquet" +# "logic__ordering_puzzle_dataset_2.9k.parquet" +# "logic__reasoning_gym_40.6k.parquet" +# "logic__synlogic_12.1k.parquet" +# "logic__zebra_puzzle_dataset_5.0k.parquet" echo "Collecting training files from ${DATA_MIX_DIR}..." @@ -239,10 +239,10 @@ rollout_dtype="float16" enable_filter_groups=False filter_groups_metric=acc max_num_gen_batches=10 -train_prompt_bsz=128 # on-policy model update batchsize: train_prompt_bsz * rollout.n +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n gen_prompt_bsz=$((train_prompt_bsz * 1)) n_resp_per_prompt=16 -train_prompt_mini_bsz=128 # model grad update batchsize +train_prompt_mini_bsz=256 # model grad update batchsize # Algorithm temperature=1.2 @@ -315,7 +315,7 @@ offload=True actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ actor_rollout_ref.rollout.enable_chunked_prefill=True \ @@ -342,7 +342,7 @@ offload=True actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.model.enable_activation_offload=${offload} \ actor_rollout_ref.model.use_liger=True \ - 
reward_model.reward_manager=async_multi_process \ + reward_model.reward_manager=dapo \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ @@ -352,10 +352,11 @@ offload=True trainer.val_before_train=False \ trainer.n_gpus_per_node=8 \ trainer.nnodes=$worker_num \ - trainer.save_freq=10 \ + trainer.save_freq=1 \ trainer.test_freq=5 \ trainer.total_epochs=5 \ - trainer.log_val_generations=50 \ + trainer.log_val_generations=0 \ trainer.resume_mode=auto \ trainer.max_actor_ckpt_to_keep=3 - # data.id_val_files="$id_val_files" \ \ No newline at end of file + # data.id_val_files="$id_val_files" \ + # trainer.log_val_generations=50 \ \ No newline at end of file From e1325140ade1c295f15f77d161256a731085c35c Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Fri, 9 Jan 2026 00:27:00 +0000 Subject: [PATCH 15/20] Synch --- scripts/tools/serve_llm_as_verifier.sh | 2 +- scripts/tools/serve_math_llm_as_verifier.sh | 22 ++ scripts/train/k2p_hero_grpo_newData.sh | 14 +- verl/utils/reward_score/arcagi.py | 45 ++-- verl/utils/reward_score/codeio.py | 16 +- verl/utils/reward_score/ifbench/__init__.py | 96 +++++---- verl/utils/reward_score/ifeval/__init__.py | 71 ++++--- .../reward_score/math_llm_judge/__init__.py | 1 + .../reward_score/reasoning_gym/__init__.py | 184 +++++++++-------- .../synlogic/arrow_maze_verifier.py | 180 ++++++++-------- .../synlogic/boolean_expressions_verifier.py | 30 +-- .../synlogic/campsite_verifier.py | 62 +++--- .../synlogic/dyck_language_errors_verifier.py | 66 +++--- ...dyck_language_reasoning_errors_verifier.py | 66 +++--- .../synlogic/dyck_language_verifier.py | 36 ++-- .../synlogic/game_of_buggy_tables_verifier.py | 34 ++-- .../synlogic/goods_exchange_verifier.py | 46 +++-- .../synlogic/math_path_verifier.py | 103 +++++----- .../synlogic/minesweeper_verifier.py | 26 ++- .../synlogic/norinori_verifier.py | 94 
+++++---- .../synlogic/number_wall_verifier.py | 88 ++++---- .../reward_score/synlogic/numbrix_verifier.py | 92 +++++---- .../synlogic/object_counting_verifier.py | 26 ++- .../synlogic/object_properties_verifier.py | 18 +- .../synlogic/operation_verifier.py | 20 +- .../synlogic/skyscraper_puzzle_verifier.py | 166 +++++++-------- .../synlogic/space_reasoning_tree_verifier.py | 22 +- .../synlogic/space_reasoning_verifier.py | 14 +- .../star_placement_puzzle_verifier.py | 152 +++++++------- .../synlogic/time_sequence_verifier.py | 55 ++--- .../synlogic/web_of_lies_verifier.py | 59 +++--- .../synlogic/word_sorting_mistake_verifier.py | 28 +-- .../synlogic/word_sorting_verifier.py | 16 +- .../synlogic/wordscapes_verifier.py | 192 ++++++++++-------- 34 files changed, 1189 insertions(+), 953 deletions(-) create mode 100644 scripts/tools/serve_math_llm_as_verifier.sh diff --git a/scripts/tools/serve_llm_as_verifier.sh b/scripts/tools/serve_llm_as_verifier.sh index 0d901938..1f4345b1 100644 --- a/scripts/tools/serve_llm_as_verifier.sh +++ b/scripts/tools/serve_llm_as_verifier.sh @@ -1,6 +1,6 @@ #!/bin/bash #SBATCH --job-name=server_llm_as_verifier -#SBATCH --partition=main +#SBATCH --partition=higherprio #SBATCH --nodes=1 #SBATCH --ntasks=1 #SBATCH --cpus-per-task=64 diff --git a/scripts/tools/serve_math_llm_as_verifier.sh b/scripts/tools/serve_math_llm_as_verifier.sh new file mode 100644 index 00000000..60598825 --- /dev/null +++ b/scripts/tools/serve_math_llm_as_verifier.sh @@ -0,0 +1,22 @@ +#!/bin/bash +#SBATCH --job-name=server_math_llm_as_verifier +#SBATCH --partition=higherprio +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=64 +#SBATCH --gres=gpu:8 +#SBATCH --time=720:00:00 +#SBATCH --output=slurm/serve_math_llm_as_verifier_%j.log +#SBATCH --error=slurm/serve_math_llm_as_verifier_%j.log + + +# (1) detect this node’s primary IP +NODE_IP=$(hostname -I | awk '{print $1}') +echo "Detected NODE_IP = $NODE_IP" + +# (2) export judge URL for downstream 
clients +export MATH_LLM_JUDGE_URL="http://${NODE_IP}:8000" +echo "MATH_LLM_JUDGE_URL=$MATH_LLM_JUDGE_URL" + +# (3) launch the vLLM server bound to that IP +vllm serve openai/gpt-oss-120b --host "$NODE_IP" --data-parallel-size 8 --enable-expert-parallel \ No newline at end of file diff --git a/scripts/train/k2p_hero_grpo_newData.sh b/scripts/train/k2p_hero_grpo_newData.sh index 5b6d599b..fcee39a9 100644 --- a/scripts/train/k2p_hero_grpo_newData.sh +++ b/scripts/train/k2p_hero_grpo_newData.sh @@ -1,5 +1,5 @@ #!/bin/bash -#SBATCH --job-name=grpo-stage2-k2pRL-7domains-VaradMix +#SBATCH --job-name=grpo-stage2-k2pRL-easy50k-7domains #SBATCH --nodes=64 #SBATCH --ntasks=64 #SBATCH --ntasks-per-node=1 @@ -11,12 +11,12 @@ #SBATCH --exclusive #SBATCH --time=720:00:00 #SBATCH --partition=higherprio -#SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] +# commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-7domains-VaradMix-415521" # Fill in the checkpoint directory name to resume from, otherwise from scratch +RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-easy50k-7domains-415354" # Fill in the checkpoint directory name to resume from, otherwise from scratch export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain -export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-286:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty # =================== Cluster Environment =================== export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ @@ -84,8 +84,8 @@ dataset_names=( 
"table__hitab_7.4k.parquet" "table__multihier_2.9k.parquet" ) -# # "simulation__codeio_fixed_12.1k.parquet" -# "logic__arcagi1_297.parquet" +# "simulation__codeio_fixed_12.1k.parquet" +# "logic__arcagi1_297.parquet" # "logic__arcagi2_653.parquet" # "logic__barc_3.4k.parquet" # "logic__graph_logical_dataset_1.4k.parquet" @@ -166,7 +166,7 @@ if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet # test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" # Full data mixture (uncomment to use) -test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${synlogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}' +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', # =================== Model =================== diff --git a/verl/utils/reward_score/arcagi.py b/verl/utils/reward_score/arcagi.py index b8367d89..198db591 100644 --- a/verl/utils/reward_score/arcagi.py +++ b/verl/utils/reward_score/arcagi.py @@ -1,9 +1,11 @@ import re import ast import numpy as np +from verl.utils.py_functional import timeout_limit + def extract_solution(solution_str): - 
answer_pattern = r'(.*?)' + answer_pattern = r"(.*?)" matches = list(re.finditer(answer_pattern, solution_str, flags=re.DOTALL)) if matches: final_answer = matches[-1].group(1).strip() @@ -11,8 +13,8 @@ def extract_solution(solution_str): final_answer = final_answer.replace("...", "-1") try: # Find the part of the text that looks like a nested list - start = final_answer.index('[[') - end = final_answer.index(']]', start) + 2 + start = final_answer.index("[[") + end = final_answer.index("]]", start) + 2 array_str = final_answer[start:end] # Use ast.literal_eval to safely evaluate the string as a Python expression array = ast.literal_eval(array_str) @@ -25,6 +27,7 @@ def extract_solution(solution_str): else: return [[0]] + def pad_array_with_value(array, target_shape, pad_value): """ Pad the given array to the target shape with the specified pad value. @@ -53,7 +56,7 @@ def pad_array_with_value(array, target_shape, pad_value): except Exception as e: array = np.array([[0]]) original_shape = array.shape - padded_array[:original_shape[0], :original_shape[1]] = array + padded_array[: original_shape[0], : original_shape[1]] = array return padded_array @@ -74,19 +77,37 @@ def compare_solutions_with_padding(generated_output, correct_output, pad_value=- max_rows = max(len(generated_output), len(correct_output)) max_cols = max(len(generated_output[0]), len(correct_output[0])) target_shape = (max_rows, max_cols) - + padded_generated = pad_array_with_value(generated_output, target_shape, pad_value) padded_correct = pad_array_with_value(correct_output, target_shape, pad_value) total_pixels = max_rows * max_cols - correct_pixels = np.sum((padded_generated == padded_correct) & (padded_generated != pad_value) & (padded_correct != pad_value)) - correct_percentage = (correct_pixels / total_pixels) + correct_pixels = np.sum( + (padded_generated == padded_correct) + & (padded_generated != pad_value) + & (padded_correct != pad_value) + ) + correct_percentage = correct_pixels / 
total_pixels is_correct = float(correct_pixels == total_pixels) return is_correct, correct_percentage +def compute_score( + model_output: str, ground_truth: np.ndarray, extra_info: any = None +) -> float: + @timeout_limit(seconds=10) + def _compute_score_with_timeout(): + model_output_str = str(model_output) + final_answer = extract_solution(model_output_str) + is_correct, correct_percentage = compare_solutions_with_padding( + final_answer, ground_truth + ) + return {"score": is_correct, "acc": is_correct} -def compute_score(model_output: str, ground_truth: np.ndarray, extra_info: any = None) -> float: - model_output = str(model_output) - final_answer = extract_solution(model_output) - is_correct, correct_percentage = compare_solutions_with_padding(final_answer, ground_truth) - return {"score": is_correct, "acc": is_correct} \ No newline at end of file + try: + return _compute_score_with_timeout() + except TimeoutError: + print("Computation timed out in arcagi") + return {"score": 0.0, "acc": 0.0} + except Exception as e: + print(f"Error in compute_score in arcagi: {e}") + return {"score": 0.0, "acc": 0.0} diff --git a/verl/utils/reward_score/codeio.py b/verl/utils/reward_score/codeio.py index 44955ae4..67b8b19b 100644 --- a/verl/utils/reward_score/codeio.py +++ b/verl/utils/reward_score/codeio.py @@ -2,6 +2,7 @@ import ast import re from typing import Dict, Any, Tuple, List +from verl.utils.py_functional import timeout_limit def normalize(obj: Any) -> Any: """ @@ -139,8 +140,19 @@ def compute_score(model_output: str, ground_truth: str, extra_info: any = None) """ Compute score dict for evaluation harness. 
""" - correct, _ = check_accuracy(str(model_output), str(ground_truth), any_order=False) - return {"score": correct, "acc": correct} + @timeout_limit(seconds=10) + def _compute_score_with_timeout(): + correct, _ = check_accuracy(str(model_output), str(ground_truth), any_order=False) + return {"score": correct, "acc": correct} + + try: + return _compute_score_with_timeout() + except TimeoutError: + print("Computation timed out in codeio") + return {"score": False, "acc": False} + except Exception as e: + print(f"Error in compute_score in codeio: {e}") + return {"score": False, "acc": False} # --------------------------- test --------------------------- # diff --git a/verl/utils/reward_score/ifbench/__init__.py b/verl/utils/reward_score/ifbench/__init__.py index 5d53be4d..1369593e 100644 --- a/verl/utils/reward_score/ifbench/__init__.py +++ b/verl/utils/reward_score/ifbench/__init__.py @@ -2,6 +2,7 @@ import json import numpy as np +from verl.utils.py_functional import timeout_limit from .instructions_registry import INSTRUCTION_DICT @@ -18,50 +19,61 @@ def compute_score(solution_str, ground_truth, extra_info=None): Returns: dict: {"score": float, "acc": bool} """ - # Strip off any thinking section - if "" in solution_str: - answer = solution_str.split("", 1)[1].strip() - else: - answer = solution_str.strip() + @timeout_limit(seconds=30) + def _compute_score_with_timeout(): + # Strip off any thinking section + if "" in solution_str: + answer = solution_str.split("", 1)[1].strip() + else: + answer = solution_str.strip() - # Parse ground_truth if it's a string - if isinstance(ground_truth, str): - try: - gt_list = ast.literal_eval(ground_truth) - except Exception: - gt_list = json.loads(ground_truth) - else: - gt_list = ground_truth + # Parse ground_truth if it's a string + if isinstance(ground_truth, str): + try: + gt_list = ast.literal_eval(ground_truth) + except Exception: + gt_list = json.loads(ground_truth) + else: + gt_list = ground_truth - # Take the first set 
of constraints - if not isinstance(gt_list, list) or not gt_list: - return {"score": 0.0, "acc": False} - first_item = gt_list[0] - instruction_ids = first_item.get("instruction_id", []) - kwargs_list = first_item.get("kwargs", []) + # Take the first set of constraints + if not isinstance(gt_list, list) or not gt_list: + return {"score": 0.0, "acc": False} + first_item = gt_list[0] + instruction_ids = first_item.get("instruction_id", []) + kwargs_list = first_item.get("kwargs", []) + + # Evaluate each instruction + results = [] + for instr_id, raw_args in zip(instruction_ids, kwargs_list): + # Prepare args dict + args = {} if raw_args is None else raw_args + # Convert numpy and floats + clean_args = {} + for key, val in args.items(): + if isinstance(val, float): + clean_args[key] = int(val) + elif isinstance(val, np.ndarray): + clean_args[key] = val.tolist() + else: + clean_args[key] = val - # Evaluate each instruction - results = [] - for instr_id, raw_args in zip(instruction_ids, kwargs_list): - # Prepare args dict - args = {} if raw_args is None else raw_args - # Convert numpy and floats - clean_args = {} - for key, val in args.items(): - if isinstance(val, float): - clean_args[key] = int(val) - elif isinstance(val, np.ndarray): - clean_args[key] = val.tolist() - else: - clean_args[key] = val + # Build and check instruction + instr_cls = INSTRUCTION_DICT[instr_id] + instr = instr_cls(instr_id) + instr.build_description(**clean_args) + passed = bool(answer and instr.check_following(answer)) + results.append(passed) - # Build and check instruction - instr_cls = INSTRUCTION_DICT[instr_id] - instr = instr_cls(instr_id) - instr.build_description(**clean_args) - passed = bool(answer and instr.check_following(answer)) - results.append(passed) + # Return 1.0 if all constraints are satisfied, 0.0 otherwise + score = 1.0 if all(results) else 0.0 + return {"score": score, "acc": score == 1.0} - # Return 1.0 if all constraints are satisfied, 0.0 otherwise - score = 1.0 if 
all(results) else 0.0 - return {"score": score, "acc": score == 1.0} + try: + return _compute_score_with_timeout() + except TimeoutError: + print("Computation timed out in ifbench") + return {"score": 0.0, "acc": False} + except Exception as e: + print(f"Error in compute_score in ifbench: {e}") + return {"score": 0.0, "acc": False} diff --git a/verl/utils/reward_score/ifeval/__init__.py b/verl/utils/reward_score/ifeval/__init__.py index 3f6c312f..3d651190 100644 --- a/verl/utils/reward_score/ifeval/__init__.py +++ b/verl/utils/reward_score/ifeval/__init__.py @@ -1,5 +1,7 @@ from verl.utils.reward_score.ifeval import instructions_registry import numpy as np +from verl.utils.py_functional import timeout_limit + def compute_score(solution_str, ground_truth, extra_info): """The scoring function for IFEval. @@ -12,36 +14,47 @@ def compute_score(solution_str, ground_truth, extra_info): format_score: the score for the format score: the score for the correct answer """ - if "" in solution_str: - answer = solution_str.split("")[1] - else: - answer = solution_str - is_following_list = [] - for index, instruction_id in enumerate(extra_info["instruction_id_list"]): - instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id] - instruction = instruction_cls(instruction_id) + @timeout_limit(seconds=60) + def _compute_score_with_timeout(): + if "" in solution_str: + answer = solution_str.split("")[1] + else: + answer = solution_str + is_following_list = [] + for index, instruction_id in enumerate(extra_info["instruction_id_list"]): + instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id] + instruction = instruction_cls(instruction_id) - # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method. 
- # print(ground_truth) - # print(extra_info["instruction_id_list"]) - # print(len(extra_info["instruction_id_list"])) - # if v is , turn it into a list - # if v is float, turn it into an int - kwargs = { - k: int(v) if isinstance(v, float) else v.tolist() if isinstance(v, np.ndarray) else v for k, v in ground_truth[index].items() if v is not None - } + # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method. + # print(ground_truth) + # print(extra_info["instruction_id_list"]) + # print(len(extra_info["instruction_id_list"])) + # if v is , turn it into a list + # if v is float, turn it into an int + kwargs = { + k: int(v) if isinstance(v, float) else v.tolist() if isinstance(v, np.ndarray) else v for k, v in ground_truth[index].items() if v is not None + } - instruction.build_description(**kwargs) - args = instruction.get_instruction_args() - if args and "prompt" in args: - instruction.build_description(prompt=extra_info["prompt"]) + instruction.build_description(**kwargs) + args = instruction.get_instruction_args() + if args and "prompt" in args: + instruction.build_description(prompt=extra_info["prompt"]) - if answer.strip() and instruction.check_following(answer): - is_following_list.append(True) - else: - is_following_list.append(False) + if answer.strip() and instruction.check_following(answer): + is_following_list.append(True) + else: + is_following_list.append(False) - return { - "score": all(is_following_list), - "acc": all(is_following_list), - } + return { + "score": all(is_following_list), + "acc": all(is_following_list), + } + + try: + return _compute_score_with_timeout() + except TimeoutError: + print("Computation timed out in ifbench") + return {"score": 0.0, "acc": False} + except Exception as e: + print(f"Error in compute_score in ifbench: {e}") + return {"score": 0.0, "acc": False} \ No newline at end of file diff --git a/verl/utils/reward_score/math_llm_judge/__init__.py 
b/verl/utils/reward_score/math_llm_judge/__init__.py index e7b9224d..e0fba8d5 100644 --- a/verl/utils/reward_score/math_llm_judge/__init__.py +++ b/verl/utils/reward_score/math_llm_judge/__init__.py @@ -36,6 +36,7 @@ """ import re +import os import math import sympy diff --git a/verl/utils/reward_score/reasoning_gym/__init__.py b/verl/utils/reward_score/reasoning_gym/__init__.py index 64eba44c..4ba9139f 100644 --- a/verl/utils/reward_score/reasoning_gym/__init__.py +++ b/verl/utils/reward_score/reasoning_gym/__init__.py @@ -1,6 +1,7 @@ import reasoning_gym import json import re +from verl.utils.py_functional import timeout_limit def compute_score(solution_str, ground_truth, extra_info=None, item=None): """ @@ -15,96 +16,107 @@ def compute_score(solution_str, ground_truth, extra_info=None, item=None): Returns: dict: {"score": float, "acc": float} """ - task = None - entry = None - - # 1. Parse extra_info - extra_info_dict = {} - metadata = None - - if extra_info: - if isinstance(extra_info, str): - try: - extra_info_dict = json.loads(extra_info) - except Exception: - extra_info_dict = {} - else: - extra_info_dict = extra_info - - # Get task first - task = extra_info_dict.get("task") - entry = extra_info_dict.get("entry") - - # Handle metadata field if present - if "metadata" in extra_info_dict: - if isinstance(extra_info_dict["metadata"], str): + @timeout_limit(seconds=10) + def _compute_score_with_timeout(): + task = None + entry = None + + # 1. Parse extra_info + extra_info_dict = {} + metadata = None + + if extra_info: + if isinstance(extra_info, str): try: - metadata = json.loads(extra_info_dict["metadata"]) + extra_info_dict = json.loads(extra_info) except Exception: - metadata = {} - elif isinstance(extra_info_dict["metadata"], dict): - metadata = extra_info_dict["metadata"] - - # 2. Try to get from item (fallback - this is rarely used in actual training) - if not task and item and isinstance(item, dict): - task = item.get("ability") - - # 3. 
Try to get from ground_truth - if not task and isinstance(ground_truth, dict): - task = ground_truth.get("task") - entry = ground_truth - - if not task: - raise ValueError("task must be provided in extra_info, item, or ground_truth dict.") - - # 4. Get scoring function - scorer = reasoning_gym.get_score_answer_fn(task) - - # 5. Get entry - if entry is None: - entry = {"answer": ground_truth} - - # Build metadata field, prioritizing extra_info metadata - if isinstance(entry, dict): - if "metadata" not in entry or not isinstance(entry["metadata"], dict): - entry["metadata"] = {} - if metadata is not None: - entry["metadata"].update(metadata) - if task is not None: - entry["metadata"]["task"] = task - entry["metadata"]["solution_str"] = solution_str - entry["metadata"]["ground_truth"] = ground_truth - if extra_info is not None: - entry["metadata"]["extra_info"] = extra_info - if item is not None: - entry["metadata"]["item"] = item - - # 6. Extract clean answer from solution_str - clean_answer = extract_answer_from_solution(solution_str) - - # 7. Scoring with task-specific fixes - debug_log_path = "reasoning_gym_debug.log" + extra_info_dict = {} + else: + extra_info_dict = extra_info + + # Get task first + task = extra_info_dict.get("task") + entry = extra_info_dict.get("entry") + + # Handle metadata field if present + if "metadata" in extra_info_dict: + if isinstance(extra_info_dict["metadata"], str): + try: + metadata = json.loads(extra_info_dict["metadata"]) + except Exception: + metadata = {} + elif isinstance(extra_info_dict["metadata"], dict): + metadata = extra_info_dict["metadata"] + + # 2. Try to get from item (fallback - this is rarely used in actual training) + if not task and item and isinstance(item, dict): + task = item.get("ability") + + # 3. 
Try to get from ground_truth + if not task and isinstance(ground_truth, dict): + task = ground_truth.get("task") + entry = ground_truth + + if not task: + raise ValueError("task must be provided in extra_info, item, or ground_truth dict.") + + # 4. Get scoring function + scorer = reasoning_gym.get_score_answer_fn(task) + + # 5. Get entry + if entry is None: + entry = {"answer": ground_truth} + + # Build metadata field, prioritizing extra_info metadata + if isinstance(entry, dict): + if "metadata" not in entry or not isinstance(entry["metadata"], dict): + entry["metadata"] = {} + if metadata is not None: + entry["metadata"].update(metadata) + if task is not None: + entry["metadata"]["task"] = task + entry["metadata"]["solution_str"] = solution_str + entry["metadata"]["ground_truth"] = ground_truth + if extra_info is not None: + entry["metadata"]["extra_info"] = extra_info + if item is not None: + entry["metadata"]["item"] = item + + # 6. Extract clean answer from solution_str + clean_answer = extract_answer_from_solution(solution_str) + + # 7. 
Scoring with task-specific fixes + debug_log_path = "reasoning_gym_debug.log" + try: + with open(debug_log_path, "a", encoding="utf-8") as f: + f.write("[DEBUG] solution_str: {}\n".format(solution_str)) + f.write("[DEBUG] clean_answer: {}\n".format(clean_answer)) + f.write("[DEBUG] ground_truth: {}\n".format(ground_truth)) + f.write("[DEBUG] task: {}\n".format(task)) + f.write("[DEBUG] metadata: {}\n".format(json.dumps(entry.get("metadata", {}), ensure_ascii=False, indent=2))) + + # Get raw score from reasoning_gym using clean answer + raw_score = scorer(answer=clean_answer, entry=entry) + + # Apply task-specific corrections for known issues + corrected_score = apply_task_specific_corrections(task, solution_str, ground_truth, raw_score) + + f.write("[DEBUG] raw_score: {}\n".format(raw_score)) + f.write("[DEBUG] corrected_score: {}\n".format(corrected_score)) + + return {"score": float(corrected_score), "acc": float(corrected_score)} + except Exception as e: + with open(debug_log_path, "a", encoding="utf-8") as f: + f.write(f"Error in reasoning gym scoring: {e}\n") + return {"score": 0.0, "acc": 0.0} + try: - with open(debug_log_path, "a", encoding="utf-8") as f: - f.write("[DEBUG] solution_str: {}\n".format(solution_str)) - f.write("[DEBUG] clean_answer: {}\n".format(clean_answer)) - f.write("[DEBUG] ground_truth: {}\n".format(ground_truth)) - f.write("[DEBUG] task: {}\n".format(task)) - f.write("[DEBUG] metadata: {}\n".format(json.dumps(entry.get("metadata", {}), ensure_ascii=False, indent=2))) - - # Get raw score from reasoning_gym using clean answer - raw_score = scorer(answer=clean_answer, entry=entry) - - # Apply task-specific corrections for known issues - corrected_score = apply_task_specific_corrections(task, solution_str, ground_truth, raw_score) - - f.write("[DEBUG] raw_score: {}\n".format(raw_score)) - f.write("[DEBUG] corrected_score: {}\n".format(corrected_score)) - - return {"score": float(corrected_score), "acc": float(corrected_score)} + return 
_compute_score_with_timeout() + except TimeoutError: + print("Computation timed out in reasoning_gym") + return {"score": 0.0, "acc": 0.0} except Exception as e: - with open(debug_log_path, "a", encoding="utf-8") as f: - f.write(f"Error in reasoning gym scoring: {e}\n") + print(f"Error in compute_score in reasoning_gym: {e}") return {"score": 0.0, "acc": 0.0} diff --git a/verl/utils/reward_score/synlogic/arrow_maze_verifier.py b/verl/utils/reward_score/synlogic/arrow_maze_verifier.py index 14f5977f..d06cfc63 100644 --- a/verl/utils/reward_score/synlogic/arrow_maze_verifier.py +++ b/verl/utils/reward_score/synlogic/arrow_maze_verifier.py @@ -3,6 +3,7 @@ from .verifier import Verifier from .data import Data import re +from verl.utils.py_functional import timeout_limit class ArrowMazeVerifier(Verifier): """ @@ -43,101 +44,112 @@ def verify(self, data: Data, test_solution_str: str) -> bool: @param test_solution_str: 测试答案字符串 (JSON格式的二维数组) @return: 答案是否正确 """ - test_answer_str = self.extract_answer(test_solution_str) - if not test_answer_str: - # print("答案为空,验证失败") - return False - - try: - # 解析测试答案 - test_answer = json.loads(test_answer_str) - - # 获取原始迷宫 - question_grid = data.metadata["maze"] - - # 检查答案是否符合要求 - if not self._verify_grid_size(test_answer, question_grid): - # print("答案网格大小与题目不匹配") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案网格大小与题目不匹配" + '\n') - f.write('-'*32 + '\n') - return False - - if not self._verify_number_positions(test_answer, question_grid): - # print("答案中数字位置或值与题目不匹配") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - 
f.write("答案中数字位置或值与题目不匹配" + '\n') - f.write('-'*32 + '\n') + @timeout_limit(seconds=60) + def _verify_with_timeout(): + test_answer_str = self.extract_answer(test_solution_str) + if not test_answer_str: + # print("答案为空,验证失败") return False + + try: + # 解析测试答案 + test_answer = json.loads(test_answer_str) - if not self._verify_all_blanks_filled(test_answer, question_grid): - # print("答案中有空格未被填满") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中有空格未被填满" + '\n') - f.write('-'*32 + '\n') - return False + # 获取原始迷宫 + question_grid = data.metadata["maze"] - if not self._verify_arrow_symbols(test_answer): - # print("答案中包含非法箭头符号") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中包含非法箭头符号" + '\n') - f.write('-'*32 + '\n') - return False - - if not self._verify_prefilled_arrows(test_answer, question_grid): - # print("答案中预填箭头与题目不一致") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中预填箭头与题目不一致" + '\n') - f.write('-'*32 + '\n') - return False + # 检查答案是否符合要求 + if not self._verify_grid_size(test_answer, question_grid): + # print("答案网格大小与题目不匹配") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案网格大小与题目不匹配" + '\n') + f.write('-'*32 + '\n') + return False + + if 
not self._verify_number_positions(test_answer, question_grid): + # print("答案中数字位置或值与题目不匹配") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中数字位置或值与题目不匹配" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_all_blanks_filled(test_answer, question_grid): + # print("答案中有空格未被填满") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中有空格未被填满" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_arrow_symbols(test_answer): + # print("答案中包含非法箭头符号") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中包含非法箭头符号" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_prefilled_arrows(test_answer, question_grid): + # print("答案中预填箭头与题目不一致") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中预填箭头与题目不一致" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_arrow_rays(test_answer): + # print("答案中存在未被射线覆盖的箭头") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中存在未被射线覆盖的箭头" + '\n') + 
f.write('-'*32 + '\n') + return False + + if not self._verify_number_rays(test_answer): + # print("答案中数字的射线箭头串总数不符合要求") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中数字的射线箭头串总数不符合要求" + '\n') + f.write('-'*32 + '\n') + return False - if not self._verify_arrow_rays(test_answer): - # print("答案中存在未被射线覆盖的箭头") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中存在未被射线覆盖的箭头" + '\n') - f.write('-'*32 + '\n') - return False + # 所有验证都通过 + # print("验证通过!") + return True - if not self._verify_number_rays(test_answer): - # print("答案中数字的射线箭头串总数不符合要求") + except Exception as e: + # print(f"验证过程中出错: {e}") with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: f.write("test_solution_str: " + test_solution_str + '\n') f.write("test_answer_str: " + test_answer_str + '\n') f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中数字的射线箭头串总数不符合要求" + '\n') - f.write('-'*32 + '\n') + f.write("验证过程中出错" + str(e) + '\n') + f.write('-'*32 + '\n') return False - - # 所有验证都通过 - # print("验证通过!") - return True - + + try: + return _verify_with_timeout() + except TimeoutError: + print("Verification timed out (ArrowMazeVerifier)") + return False except Exception as e: - # print(f"验证过程中出错: {e}") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("验证过程中出错" + str(e) + '\n') - f.write('-'*32 + '\n') + print(f"Verification error (ArrowMazeVerifier): {e}") return False 
def _verify_grid_size(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool: diff --git a/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py b/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py index 8edab88f..d5e6b968 100644 --- a/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py +++ b/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py @@ -1,6 +1,7 @@ import re from .data import Data from .verifier import Verifier +from verl.utils.py_functional import timeout_limit class BooleanExpressionsVerifier(Verifier): """ @@ -8,19 +9,22 @@ class BooleanExpressionsVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - test_answer = self.extract_answer(test_answer) - if test_answer is None: - return False - # 提取所有字母(a-z和A-Z) - test_answer_letters = re.findall(r'[a-zA-Z]', test_answer) - ground_truth_letters = re.findall(r'[a-zA-Z]', data.answer) - test_answer_letters = self.lower(test_answer_letters) - ground_truth_letters = self.lower(ground_truth_letters) - # 转换为集合进行比较 - test_set = set(test_answer_letters) - ground_truth_set = set(ground_truth_letters) - - return test_set == ground_truth_set + @timeout_limit(seconds=10) + def _verify_with_timeout(): + test_answer_extracted = self.extract_answer(test_answer) + if test_answer_extracted is None: + return False + # 提取所有字母(a-z和A-Z) + test_answer_letters = re.findall(r'[a-zA-Z]', test_answer_extracted) + ground_truth_letters = re.findall(r'[a-zA-Z]', data.answer) + test_answer_letters = self.lower(test_answer_letters) + ground_truth_letters = self.lower(ground_truth_letters) + # 转换为集合进行比较 + test_set = set(test_answer_letters) + ground_truth_set = set(ground_truth_letters) + + return test_set == ground_truth_set + return _verify_with_timeout() except Exception as e: print("NOTE!!! parse error!!!! 
(BooleanExpressions)", e) return False diff --git a/verl/utils/reward_score/synlogic/campsite_verifier.py b/verl/utils/reward_score/synlogic/campsite_verifier.py index 4aee1a5b..60203268 100644 --- a/verl/utils/reward_score/synlogic/campsite_verifier.py +++ b/verl/utils/reward_score/synlogic/campsite_verifier.py @@ -3,6 +3,7 @@ import re import ast from typing import List, Set, Tuple, Dict +from verl.utils.py_functional import timeout_limit class CampsiteVerifier(Verifier): @@ -11,35 +12,38 @@ class CampsiteVerifier(Verifier): """ def verify(self, data: Data, test_solution: str): try: - test_answer = self.extract_answer(test_solution) - original_grid = data.metadata["grid"] - row_constraints = data.metadata["row_constraints"] - col_constraints = data.metadata["col_constraints"] - n = data.metadata["n"] - m = data.metadata["m"] - - if not test_answer: - return False - - if len(test_answer) != n or any(len(row) != m for row in test_answer): - return False - - if not self._check_trees_unchanged(original_grid, test_answer): - return False - - if not self._check_row_constraints(test_answer, row_constraints): - return False - - if not self._check_col_constraints(test_answer, col_constraints): - return False - - if not self._check_tents_not_adjacent(test_answer): - return False - - if not self._check_tent_tree_matching(test_answer): - return False - - return True + @timeout_limit(seconds=10) + def _verify_with_timeout(): + test_answer = self.extract_answer(test_solution) + original_grid = data.metadata["grid"] + row_constraints = data.metadata["row_constraints"] + col_constraints = data.metadata["col_constraints"] + n = data.metadata["n"] + m = data.metadata["m"] + + if not test_answer: + return False + + if len(test_answer) != n or any(len(row) != m for row in test_answer): + return False + + if not self._check_trees_unchanged(original_grid, test_answer): + return False + + if not self._check_row_constraints(test_answer, row_constraints): + return False + + if not 
self._check_col_constraints(test_answer, col_constraints): + return False + + if not self._check_tents_not_adjacent(test_answer): + return False + + if not self._check_tent_tree_matching(test_answer): + return False + + return True + return _verify_with_timeout() except Exception as e: print(f"Verification error (Campsite): {e}") diff --git a/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py index 3bfbdac7..1afb8ff7 100644 --- a/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py +++ b/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py @@ -1,6 +1,7 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re +from verl.utils.py_functional import timeout_limit class DyckLanguageErrorsVerifier(Verifier): @@ -16,39 +17,42 @@ def verify(self, data: Data, test_answer: str): @return: 回答是否正确的布尔值 """ try: - test_answer = self.extract_answer(test_solution=test_answer) - # 获取正确答案 - if data.metadata["is_valid"]: - correct_answer = "-1" # 合法序列对应-1 - else: - correct_answer = str(data.metadata["first_error_pos"]) - - # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'") - - # 清理和标准化答案 - test_answer = test_answer.strip() - - # 检查-1答案(合法序列) - if correct_answer == "-1": - # 如果正确答案是-1(合法序列),只接受-1作为回答 - if test_answer == "-1": - is_correct = True + @timeout_limit(seconds=10) + def _verify_with_timeout(): + test_answer_extracted = self.extract_answer(test_solution=test_answer) + # 获取正确答案 + if data.metadata["is_valid"]: + correct_answer = "-1" # 合法序列对应-1 else: - is_correct = False - else: - # 正确答案是位置数字,需要验证模型回答也是相同数字 - try: - is_correct = (int(test_answer) == int(correct_answer)) - except (ValueError, TypeError): - # 如果模型回答不是有效数字,验证失败 - is_correct = False - - # if is_correct: - # print("验证结果: 正确") - # else: - # print("验证结果: 错误") + correct_answer = str(data.metadata["first_error_pos"]) + + # print(f"验证: 
模型答案='{test_answer}', 正确答案='{correct_answer}'") + + # 清理和标准化答案 + test_answer_clean = test_answer_extracted.strip() + + # 检查-1答案(合法序列) + if correct_answer == "-1": + # 如果正确答案是-1(合法序列),只接受-1作为回答 + if test_answer_clean == "-1": + is_correct = True + else: + is_correct = False + else: + # 正确答案是位置数字,需要验证模型回答也是相同数字 + try: + is_correct = (int(test_answer_clean) == int(correct_answer)) + except (ValueError, TypeError): + # 如果模型回答不是有效数字,验证失败 + is_correct = False - return is_correct + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + + return is_correct + return _verify_with_timeout() except Exception as e: print(f"Verification error (DyckLanguageErrors): {e}") diff --git a/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py index 03f2b95f..17c68087 100644 --- a/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py +++ b/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py @@ -1,6 +1,7 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re +from verl.utils.py_functional import timeout_limit class DyckLanguageReasoningErrorsVerifier(Verifier): @@ -16,39 +17,42 @@ def verify(self, data: Data, test_answer: str): @return: 回答是否正确的布尔值 """ try: - test_answer = self.extract_answer(test_solution=test_answer) - # 获取元数据中的正确答案 - correct_indices = data.metadata["error_indices"] - # 格式化为正确的答案字符串格式 - expected_answer = self._format_answer(correct_indices) - - # print(f"验证: 模型答案='{test_answer}', 正确答案='{expected_answer}'") - - # 检查不明确的答案 - if "不确定" in test_answer or "不知道" in test_answer or "unclear" in test_answer.lower(): - # print("验证结果: 错误") - return False - - # 清理模型答案,允许一定的格式变化 - cleaned_test_answer = self._standardize_answer(test_answer) - - if not correct_indices and (cleaned_test_answer == "" or cleaned_test_answer.lower() in ["无问题", "no", "无错误", "no error", 
"no errors", "no mistakes", "all correct"]): - # 如果没有错误,且模型回答是空字符串或表示无问题,则正确 - is_correct = True - else: - # 将两个答案转换为数字集合进行比较 - test_error_indices = self._extract_error_indices(cleaned_test_answer) - expected_error_indices = set(correct_indices) + @timeout_limit(seconds=10) + def _verify_with_timeout(): + test_answer_extracted = self.extract_answer(test_solution=test_answer) + # 获取元数据中的正确答案 + correct_indices = data.metadata["error_indices"] + # 格式化为正确的答案字符串格式 + expected_answer = self._format_answer(correct_indices) - # 检查两个集合是否相同 - is_correct = test_error_indices == expected_error_indices - - # if is_correct: - # print("验证结果: 正确") - # else: - # print("验证结果: 错误") + # print(f"验证: 模型答案='{test_answer}', 正确答案='{expected_answer}'") + + # 检查不明确的答案 + if "不确定" in test_answer_extracted or "不知道" in test_answer_extracted or "unclear" in test_answer_extracted.lower(): + # print("验证结果: 错误") + return False + + # 清理模型答案,允许一定的格式变化 + cleaned_test_answer = self._standardize_answer(test_answer_extracted) + + if not correct_indices and (cleaned_test_answer == "" or cleaned_test_answer.lower() in ["无问题", "no", "无错误", "no error", "no errors", "no mistakes", "all correct"]): + # 如果没有错误,且模型回答是空字符串或表示无问题,则正确 + is_correct = True + else: + # 将两个答案转换为数字集合进行比较 + test_error_indices = self._extract_error_indices(cleaned_test_answer) + expected_error_indices = set(correct_indices) + + # 检查两个集合是否相同 + is_correct = test_error_indices == expected_error_indices - return is_correct + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + + return is_correct + return _verify_with_timeout() except Exception as e: print(f"Verification error (DyckLanguageReasoningErrors): {e}") diff --git a/verl/utils/reward_score/synlogic/dyck_language_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_verifier.py index e986f66c..20b8125c 100644 --- a/verl/utils/reward_score/synlogic/dyck_language_verifier.py +++ b/verl/utils/reward_score/synlogic/dyck_language_verifier.py @@ -1,6 +1,7 @@ from 
.data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re +from verl.utils.py_functional import timeout_limit class DyckLanguageVerifier(Verifier): @@ -16,23 +17,26 @@ def verify(self, data: Data, test_answer: str) -> bool: @return: 回答是否正确的布尔值 """ try: - # 获取元数据中的完整序列 - full_sequence = data.metadata["full_sequence"] - - # print(f"验证: 模型答案='{test_answer}', 完整序列='{full_sequence}'") - - # 从模型回答中提取答案 - extracted_answer = self.extract_answer(test_answer) - - # 检查答案是否完全匹配 - is_correct = (extracted_answer == full_sequence) - - # if is_correct: - # print("验证结果: 正确") - # else: - # print("验证结果: 错误") + @timeout_limit(seconds=10) + def _verify_with_timeout(): + # 获取元数据中的完整序列 + full_sequence = data.metadata["full_sequence"] + + # print(f"验证: 模型答案='{test_answer}', 完整序列='{full_sequence}'") + + # 从模型回答中提取答案 + extracted_answer = self.extract_answer(test_answer) + + # 检查答案是否完全匹配 + is_correct = (extracted_answer == full_sequence) - return is_correct + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + + return is_correct + return _verify_with_timeout() except Exception as e: print(f"Verification error (DyckLanguage): {e}") diff --git a/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py b/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py index b4c8bdb2..7993c506 100644 --- a/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py +++ b/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py @@ -1,6 +1,7 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re +from verl.utils.py_functional import timeout_limit class BuggyTableVerifier(Verifier): """ @@ -26,19 +27,26 @@ def verify(self, data: Data, test_answer: str) -> bool: @param test_answer: The answer provided by the LLM to verify @return: bool indicating whether the answer is correct """ - # Extract the expected answer from the Data object - 
expected_answer = data.answer if data and hasattr(data, 'answer') else "" - - # For empty strings, compare directly - if not expected_answer and not test_answer: - return True - - # Extract and normalize both answers - normalized_expected = self._extract_answer(expected_answer) - normalized_test = self._extract_answer(test_answer) - - # Direct comparison of normalized answers - return normalized_expected == normalized_test + try: + @timeout_limit(seconds=10) + def _verify_with_timeout(): + # Extract the expected answer from the Data object + expected_answer = data.answer if data and hasattr(data, 'answer') else "" + + # For empty strings, compare directly + if not expected_answer and not test_answer: + return True + + # Extract and normalize both answers + normalized_expected = self._extract_answer(expected_answer) + normalized_test = self._extract_answer(test_answer) + + # Direct comparison of normalized answers + return normalized_expected == normalized_test + return _verify_with_timeout() + except Exception as e: + # print(f"Verification error (BuggyTable): {e}") + return False def _is_raw_numeric_answer(self, value: str) -> bool: """ diff --git a/verl/utils/reward_score/synlogic/goods_exchange_verifier.py b/verl/utils/reward_score/synlogic/goods_exchange_verifier.py index 922e9268..fbe1eff2 100644 --- a/verl/utils/reward_score/synlogic/goods_exchange_verifier.py +++ b/verl/utils/reward_score/synlogic/goods_exchange_verifier.py @@ -1,6 +1,7 @@ import re from .data import Data from .verifier import Verifier +from verl.utils.py_functional import timeout_limit class GoodsExchangeVerifier(Verifier): """ @@ -15,28 +16,31 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - test_answer = self.extract_answer(test_solution) - # 获取元数据中的正确答案 - correct_answer = data.metadata["owns_after"] - - # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'") - - # 解析模型答案 - model_ownership = self._parse_answer(test_answer) - # 解析正确答案 - 
correct_ownership = self._parse_answer(correct_answer) - - # 比较两个答案是否完全一致 - is_correct = self._compare_answers(model_ownership, correct_ownership) - - # if is_correct: - # print("验证结果: 正确") - # else: - # print("验证结果: 错误") - # # 打印详细的不匹配信息 - # self._print_difference(model_ownership, correct_ownership) + @timeout_limit(seconds=10) + def _verify_with_timeout(): + test_answer = self.extract_answer(test_solution) + # 获取元数据中的正确答案 + correct_answer = data.metadata["owns_after"] - return is_correct + # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'") + + # 解析模型答案 + model_ownership = self._parse_answer(test_answer) + # 解析正确答案 + correct_ownership = self._parse_answer(correct_answer) + + # 比较两个答案是否完全一致 + is_correct = self._compare_answers(model_ownership, correct_ownership) + + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + # # 打印详细的不匹配信息 + # self._print_difference(model_ownership, correct_ownership) + + return is_correct + return _verify_with_timeout() except Exception as e: print(f"Verification error (GoodsExchange): {e}") diff --git a/verl/utils/reward_score/synlogic/math_path_verifier.py b/verl/utils/reward_score/synlogic/math_path_verifier.py index d0df1c4a..5ebe1d81 100644 --- a/verl/utils/reward_score/synlogic/math_path_verifier.py +++ b/verl/utils/reward_score/synlogic/math_path_verifier.py @@ -3,6 +3,7 @@ import numpy as np from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +from verl.utils.py_functional import timeout_limit class MathPathVerifier(Verifier): @@ -18,59 +19,67 @@ def verify(self, data: Data, test_answer: str): @return: 回答是否正确的布尔值 """ try: - test_answer = self.extract_answer(test_solution=test_answer) - except Exception as e: - print(f"NOTE!!! parse error!!!! 
(MathPath): {e}") - return False + @timeout_limit(seconds=10) + def _verify_with_timeout(): + try: + test_answer_extracted = self.extract_answer(test_solution=test_answer) + except Exception as e: + # print(f"NOTE!!! parse error!!!! (MathPath): {e}") + return False - try: - # 解析元数据 - metadata = data.metadata - ref_expr = metadata["ref_expr"] - query_expr = metadata["query_expr"] + try: + # 解析元数据 + metadata = data.metadata + ref_expr = metadata["ref_expr"] + query_expr = metadata["query_expr"] - # 验证数字是否被篡改,数字是否在0-9之间。 - test_tmp = test_answer.replace(' ', '').strip() - query_tmp = query_expr.replace(' ', '').strip() - ref_tmp = ref_expr.replace(' ', '').strip() - query_nums = [x for x in query_tmp if '0'<=x<='9' or x=='?'] - test_nums = [x for x in test_tmp if '0'<=x<='9'] - if len(query_nums)!=len(test_nums): - # print(f"所填数字数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") - return False - else: - for ind, x in enumerate(query_nums): - if x=='?': - continue - if x!=test_nums[ind]: - # print(f"表达式数字被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + # 验证数字是否被篡改,数字是否在0-9之间。 + test_tmp = test_answer_extracted.replace(' ', '').strip() + query_tmp = query_expr.replace(' ', '').strip() + ref_tmp = ref_expr.replace(' ', '').strip() + query_nums = [x for x in query_tmp if '0'<=x<='9' or x=='?'] + test_nums = [x for x in test_tmp if '0'<=x<='9'] + if len(query_nums)!=len(test_nums): + # print(f"所填数字数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") return False + else: + for ind, x in enumerate(query_nums): + if x=='?': + continue + if x!=test_nums[ind]: + # print(f"表达式数字被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + + query_symbols = [x for x in query_tmp if x in ['+', '-', '*', '/', '%']] + test_symbols = [x for x in test_tmp if x in ['+', '-', '*', '/', '%']] + if len(query_symbols)!=len(test_symbols): + # print(f"表达式运算符号数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + else: + for ind, x in enumerate(query_symbols): + if 
x!=test_symbols[ind]: + # print(f"表达式运算符号被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False - query_symbols = [x for x in query_tmp if x in ['+', '-', '*', '/', '%']] - test_symbols = [x for x in test_tmp if x in ['+', '-', '*', '/', '%']] - if len(query_symbols)!=len(test_symbols): - # print(f"表达式运算符号数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") - return False - else: - for ind, x in enumerate(query_symbols): - if x!=test_symbols[ind]: - # print(f"表达式运算符号被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + # 验证回答中的等式是否成立 + try: + tmp = test_tmp.replace('=', '==') + if not eval(tmp): + # print(f"等式不成立!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + except: + # print(f"运算表达式错误!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") return False - - # 验证回答中的等式是否成立 - try: - tmp = test_tmp.replace('=', '==') - if not eval(tmp): - # print(f"等式不成立!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + + + # 所有检查都通过 + # print("验证结果: 正确") + return True + + except Exception as e: + print(f"Verification error (MathPath): {e}") return False - except: - # print(f"运算表达式错误!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") - return False - - - # 所有检查都通过 - # print("验证结果: 正确") - return True + return _verify_with_timeout() except Exception as e: print(f"Verification error (MathPath): {e}") diff --git a/verl/utils/reward_score/synlogic/minesweeper_verifier.py b/verl/utils/reward_score/synlogic/minesweeper_verifier.py index 1b279140..f0f5ae79 100644 --- a/verl/utils/reward_score/synlogic/minesweeper_verifier.py +++ b/verl/utils/reward_score/synlogic/minesweeper_verifier.py @@ -3,6 +3,7 @@ import re import json from typing import List, Tuple +from verl.utils.py_functional import timeout_limit class MinesweeperVerifier(Verifier): @@ -12,17 +13,20 @@ class MinesweeperVerifier(Verifier): """ def verify(self, data: Data, test_solution: str, **kwargs): try: - # 从解答中提取地雷坐标 - predicted_mines = self.extract_answer(test_solution) - - # 从metadata中获取确定性地雷坐标 - expected_mines = 
data.metadata["current_mines"] - - # 验证提取的坐标是否正确 - if set(tuple(mine) for mine in predicted_mines) == set(tuple(mine) for mine in expected_mines): - return True - - return False + @timeout_limit(seconds=10) + def _verify_with_timeout(): + # 从解答中提取地雷坐标 + predicted_mines = self.extract_answer(test_solution) + + # 从metadata中获取确定性地雷坐标 + expected_mines = data.metadata["current_mines"] + + # 验证提取的坐标是否正确 + if set(tuple(mine) for mine in predicted_mines) == set(tuple(mine) for mine in expected_mines): + return True + + return False + return _verify_with_timeout() except Exception as e: # 如果验证过程中发生任何错误,返回False diff --git a/verl/utils/reward_score/synlogic/norinori_verifier.py b/verl/utils/reward_score/synlogic/norinori_verifier.py index 95cc98ed..c817198d 100644 --- a/verl/utils/reward_score/synlogic/norinori_verifier.py +++ b/verl/utils/reward_score/synlogic/norinori_verifier.py @@ -2,6 +2,7 @@ from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re from collections import defaultdict +from verl.utils.py_functional import timeout_limit class NorinoriVerifier(Verifier): """ @@ -24,54 +25,57 @@ def verify(self, data: Data, test_solution: str): bool -- 答案是否正确 """ try: - # 从游戏数据中获取区域网格 - region_grid = data.metadata["region_grid"] - n = len(region_grid) - - # 解析答案 - dominoes = self._parse_answer(test_solution) - if dominoes is None: - return False - - # 检查多米诺形状 - if not self._check_domino_shapes(dominoes): - return False - - # 创建覆盖网格 - covered = [[False for _ in range(n)] for _ in range(n)] - for domino in dominoes: - for i, j in domino: - # 转换为0-indexed - i -= 1 - j -= 1 - if i < 0 or i >= n or j < 0 or j >= n: - return False # 坐标超出范围 - if covered[i][j]: - return False # 格子被多次覆盖 - covered[i][j] = True - - # 检查多米诺之间是否相邻 - if not self._check_domino_adjacency(dominoes, n): - return False - - # 检查每个区域是否恰好有两个格子被覆盖 - region_coverage = defaultdict(int) - for i in range(n): - for j in range(n): - if covered[i][j] and region_grid[i][j] != "X": - 
region_coverage[region_grid[i][j]] += 1 - - for region, count in region_coverage.items(): - if count != 2: + @timeout_limit(seconds=10) + def _verify_with_timeout(): + # 从游戏数据中获取区域网格 + region_grid = data.metadata["region_grid"] + n = len(region_grid) + + # 解析答案 + dominoes = self._parse_answer(test_solution) + if dominoes is None: return False - - # 检查所有阴影格子是否被覆盖 - for i in range(n): - for j in range(n): - if region_grid[i][j] == "X" and not covered[i][j]: + + # 检查多米诺形状 + if not self._check_domino_shapes(dominoes): + return False + + # 创建覆盖网格 + covered = [[False for _ in range(n)] for _ in range(n)] + for domino in dominoes: + for i, j in domino: + # 转换为0-indexed + i -= 1 + j -= 1 + if i < 0 or i >= n or j < 0 or j >= n: + return False # 坐标超出范围 + if covered[i][j]: + return False # 格子被多次覆盖 + covered[i][j] = True + + # 检查多米诺之间是否相邻 + if not self._check_domino_adjacency(dominoes, n): + return False + + # 检查每个区域是否恰好有两个格子被覆盖 + region_coverage = defaultdict(int) + for i in range(n): + for j in range(n): + if covered[i][j] and region_grid[i][j] != "X": + region_coverage[region_grid[i][j]] += 1 + + for region, count in region_coverage.items(): + if count != 2: return False - return True + # 检查所有阴影格子是否被覆盖 + for i in range(n): + for j in range(n): + if region_grid[i][j] == "X" and not covered[i][j]: + return False + + return True + return _verify_with_timeout() except Exception as e: print(f"Verification error (Norinori): {e}") return False diff --git a/verl/utils/reward_score/synlogic/number_wall_verifier.py b/verl/utils/reward_score/synlogic/number_wall_verifier.py index 541c0720..aefbedae 100644 --- a/verl/utils/reward_score/synlogic/number_wall_verifier.py +++ b/verl/utils/reward_score/synlogic/number_wall_verifier.py @@ -3,6 +3,7 @@ import re import json from collections import deque +from verl.utils.py_functional import timeout_limit class NumberWallVerifier(Verifier): """ @@ -11,53 +12,56 @@ class NumberWallVerifier(Verifier): """ def verify(self, data: Data, 
test_solution: str, **kwargs): try: - # 提取答案网格 - solution_grid = self.extract_answer(test_solution) - if not solution_grid: - # print("Failed to extract solution grid") - return False - - # 提取元数据 - original_grid = data.metadata["grid"] - n = data.metadata["n"] - - # 检查网格尺寸 - if len(solution_grid) != n: - # print(f"Solution grid has incorrect number of rows: {len(solution_grid)} != {n}") - return False + @timeout_limit(seconds=10) + def _verify_with_timeout(): + # 提取答案网格 + solution_grid = self.extract_answer(test_solution) + if not solution_grid: + # print("Failed to extract solution grid") + return False + + # 提取元数据 + original_grid = data.metadata["grid"] + n = data.metadata["n"] - for row in solution_grid: - if len(row) != n: - # print(f"Solution grid has incorrect number of columns: {len(row)} != {n}") + # 检查网格尺寸 + if len(solution_grid) != n: + # print(f"Solution grid has incorrect number of rows: {len(solution_grid)} != {n}") return False - # 检查每个单元格只包含数字、"X"或"A" - for cell in row: - if not (isinstance(cell, int) or cell in ["X", "A"]): - # print(f"Invalid cell content: {cell}") + for row in solution_grid: + if len(row) != n: + # print(f"Solution grid has incorrect number of columns: {len(row)} != {n}") return False - - # 检查原始数字是否保留 - if not self._check_original_numbers(original_grid, solution_grid): - # print("Original numbers not preserved") - return False - - # 检查墙壁布局是否有效(没有2×2或更大的连续墙块) - if not self._check_wall_layout(solution_grid): - # print("Invalid wall layout (2x2 or larger continuous wall blocks found)") - return False - - # 检查岛屿划分是否有效 - if not self._check_islands(solution_grid): - # print("Invalid island division") - return False - - # 检查是否有斜线边 - if not self._check_diagonal_borders(solution_grid): - # print("Invalid solution: islands have diagonal borders") - return False + + # 检查每个单元格只包含数字、"X"或"A" + for cell in row: + if not (isinstance(cell, int) or cell in ["X", "A"]): + # print(f"Invalid cell content: {cell}") + return False - return True + # 
检查原始数字是否保留 + if not self._check_original_numbers(original_grid, solution_grid): + # print("Original numbers not preserved") + return False + + # 检查墙壁布局是否有效(没有2×2或更大的连续墙块) + if not self._check_wall_layout(solution_grid): + # print("Invalid wall layout (2x2 or larger continuous wall blocks found)") + return False + + # 检查岛屿划分是否有效 + if not self._check_islands(solution_grid): + # print("Invalid island division") + return False + + # 检查是否有斜线边 + if not self._check_diagonal_borders(solution_grid): + # print("Invalid solution: islands have diagonal borders") + return False + + return True + return _verify_with_timeout() except Exception as e: # 如果验证过程中发生任何错误,返回False diff --git a/verl/utils/reward_score/synlogic/numbrix_verifier.py b/verl/utils/reward_score/synlogic/numbrix_verifier.py index f142bc7b..d7e9ea21 100644 --- a/verl/utils/reward_score/synlogic/numbrix_verifier.py +++ b/verl/utils/reward_score/synlogic/numbrix_verifier.py @@ -3,6 +3,7 @@ import re import ast import numpy as np +from verl.utils.py_functional import timeout_limit class NumbrixVerifier(Verifier): """ @@ -11,54 +12,57 @@ class NumbrixVerifier(Verifier): """ def verify(self, data: Data, test_solution: str): try: - # 提取答案网格 - test_grid = self.extract_answer(test_solution) - if not test_grid: - return False - - # 获取原始谜题和网格大小 - original_grid = data.metadata["grid"] - n = len(original_grid) - n_squared = n * n - - # 检查网格大小是否正确 - if len(test_grid) != n or any(len(row) != n for row in test_grid): - return False - - # 检查是否包含所有数字 1 到 n² - flattened_grid = [cell for row in test_grid for cell in row] - if sorted(flattened_grid) != list(range(1, n_squared + 1)): - return False - - # 检查是否保留了原始提示数字 - for i in range(n): - for j in range(n): - if original_grid[i][j] != "X" and test_grid[i][j] != original_grid[i][j]: - return False - - # 检查连续数字是否正交相邻 - for num in range(1, n_squared): - # 找到当前数字的位置 - current_pos = None - next_pos = None - for i in range(n): - for j in range(n): - if test_grid[i][j] == num: - 
current_pos = (i, j) - elif test_grid[i][j] == num + 1: - next_pos = (i, j) + @timeout_limit(seconds=10) + def _verify_with_timeout(): + # 提取答案网格 + test_grid = self.extract_answer(test_solution) + if not test_grid: + return False - if current_pos is None or next_pos is None: + # 获取原始谜题和网格大小 + original_grid = data.metadata["grid"] + n = len(original_grid) + n_squared = n * n + + # 检查网格大小是否正确 + if len(test_grid) != n or any(len(row) != n for row in test_grid): return False - # 检查是否正交相邻(曼哈顿距离为1) - i1, j1 = current_pos - i2, j2 = next_pos - manhattan_distance = abs(i1 - i2) + abs(j1 - j2) - if manhattan_distance != 1: + # 检查是否包含所有数字 1 到 n² + flattened_grid = [cell for row in test_grid for cell in row] + if sorted(flattened_grid) != list(range(1, n_squared + 1)): return False - - return True + + # 检查是否保留了原始提示数字 + for i in range(n): + for j in range(n): + if original_grid[i][j] != "X" and test_grid[i][j] != original_grid[i][j]: + return False + + # 检查连续数字是否正交相邻 + for num in range(1, n_squared): + # 找到当前数字的位置 + current_pos = None + next_pos = None + for i in range(n): + for j in range(n): + if test_grid[i][j] == num: + current_pos = (i, j) + elif test_grid[i][j] == num + 1: + next_pos = (i, j) + + if current_pos is None or next_pos is None: + return False + + # 检查是否正交相邻(曼哈顿距离为1) + i1, j1 = current_pos + i2, j2 = next_pos + manhattan_distance = abs(i1 - i2) + abs(j1 - j2) + if manhattan_distance != 1: + return False + + return True + return _verify_with_timeout() except Exception as e: print(f"Verification error (Numbrix): {e}") return False diff --git a/verl/utils/reward_score/synlogic/object_counting_verifier.py b/verl/utils/reward_score/synlogic/object_counting_verifier.py index e63e4fd9..16c1d831 100644 --- a/verl/utils/reward_score/synlogic/object_counting_verifier.py +++ b/verl/utils/reward_score/synlogic/object_counting_verifier.py @@ -1,6 +1,7 @@ import re from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +from 
verl.utils.py_functional import timeout_limit class ObjectCountingVerifier(Verifier): @@ -9,17 +10,20 @@ class ObjectCountingVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - ground_truth = int(data.answer) - parsed_answer = self.extract_answer(test_answer) - with open("solution_str_OC.txt", "a") as f: - f.write("data.answer: " + data.answer + '\n') - f.write("test_answer: " + test_answer + '\n') - f.write("parsed_answer" + parsed_answer + '\n') - f.write('-'*32 + '\n') - - if parsed_answer is None: - return False - return int(parsed_answer) == ground_truth + @timeout_limit(seconds=10) + def _verify_with_timeout(): + ground_truth = int(data.answer) + parsed_answer = self.extract_answer(test_answer) + # with open("solution_str_OC.txt", "a") as f: + # f.write("data.answer: " + data.answer + '\n') + # f.write("test_answer: " + test_answer + '\n') + # f.write("parsed_answer" + parsed_answer + '\n') + # f.write('-'*32 + '\n') + + if parsed_answer is None: + return False + return int(parsed_answer) == ground_truth + return _verify_with_timeout() except Exception as e: print(f"NOTE!!! parse error!!!! 
(ObjectCounting): {e}") diff --git a/verl/utils/reward_score/synlogic/object_properties_verifier.py b/verl/utils/reward_score/synlogic/object_properties_verifier.py index f8ad06cc..ca6a1f97 100644 --- a/verl/utils/reward_score/synlogic/object_properties_verifier.py +++ b/verl/utils/reward_score/synlogic/object_properties_verifier.py @@ -1,6 +1,7 @@ import re from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +from verl.utils.py_functional import timeout_limit class ObjectPropertiesVerifier(Verifier): @@ -9,12 +10,17 @@ class ObjectPropertiesVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - ground_truth = int(data.answer) - parsed_answer = int(self.extract_answer(test_answer)) - - if parsed_answer is None: - return False - return int(parsed_answer) == ground_truth + @timeout_limit(seconds=10) + def _verify_with_timeout(): + ground_truth = int(data.answer) + parsed_answer_str = self.extract_answer(test_answer) + + if parsed_answer_str is None: + return False + + parsed_answer = int(parsed_answer_str) + return int(parsed_answer) == ground_truth + return _verify_with_timeout() except Exception as e: print(f"NOTE!!! parse error!!!! 
(ObjectProperties): {e}") diff --git a/verl/utils/reward_score/synlogic/operation_verifier.py b/verl/utils/reward_score/synlogic/operation_verifier.py index 7f730ce8..26a795fe 100644 --- a/verl/utils/reward_score/synlogic/operation_verifier.py +++ b/verl/utils/reward_score/synlogic/operation_verifier.py @@ -2,6 +2,7 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import math_verify +from verl.utils.py_functional import timeout_limit class OperationVerifier(Verifier): @@ -10,12 +11,19 @@ class OperationVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - ground_truth = math_verify.parse(data.answer) - parsed_answer = math_verify.parse(test_answer) - - if parsed_answer is None: - return False - return math_verify.verify(parsed_answer, ground_truth) + @timeout_limit(seconds=20) + def _verify_with_timeout(): + ground_truth = math_verify.parse(data.answer, parsing_timeout=None) + parsed_answer = math_verify.parse(test_answer, parsing_timeout=None) + + if parsed_answer is None: + return False + return math_verify.verify(parsed_answer, ground_truth) + + return _verify_with_timeout() + except TimeoutError: + print("Parsing/Verification timed out (OperationVerifier)") + return False except Exception as e: print(f"NOTE!!! parse error!!!! 
(OperationVerifier): {e}") return False diff --git a/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py index 524f5859..bfc6b094 100644 --- a/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py +++ b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py @@ -3,6 +3,7 @@ import re import json import ast +from verl.utils.py_functional import timeout_limit class SkyscraperPuzzleVerifier(Verifier): @@ -18,91 +19,94 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - # 获取游戏元数据 - metadata = data.metadata - n = metadata['n'] - top = metadata['top'] - bottom = metadata['bottom'] - left = metadata['left'] - right = metadata['right'] + @timeout_limit(seconds=10) + def _verify_with_timeout(): + # 获取游戏元数据 + metadata = data.metadata + n = metadata['n'] + top = metadata['top'] + bottom = metadata['bottom'] + left = metadata['left'] + right = metadata['right'] - self.n = n - test_answer = self.extract_answer(test_solution) - - # print(f"验证: 游戏规模 {n}×{n}") - # print(f"上方提示: {top}") - # print(f"下方提示: {bottom}") - # print(f"左侧提示: {left}") - # print(f"右侧提示: {right}") - - # 使用提取好的网格数据 - grid = test_answer - - # 检查网格是否是字符串,如果是,说明提取失败 - if isinstance(grid, str): - # print("无法提取有效网格") - return False + self.n = n + test_answer = self.extract_answer(test_solution) - print("提取的网格:") - for row in grid: - print(row) - - # 检查网格规模 - if len(grid) != n or any(len(row) != n for row in grid): - # print(f"网格规模不正确,应为 {n}×{n}") - return False - - # 检查数字范围 (1 到 n) - for i in range(n): - for j in range(n): - if not isinstance(grid[i][j], int) or grid[i][j] < 1 or grid[i][j] > n: - # print(f"位置 ({i+1},{j+1}) 的值 {grid[i][j]} 不在有效范围内 (1-{n})") - return False - - # 检查每行唯一性 - for i in range(n): - if len(set(grid[i])) != n: - # print(f"第 {i+1} 行包含重复数字") - return False - - # 检查每列唯一性 - for j in range(n): - column = [grid[i][j] for i in range(n)] - if len(set(column)) != n: - # 
print(f"第 {j+1} 列包含重复数字") - return False - - # 检查从上方观察 - for j in range(n): - visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n)]) - if visible_count != top[j]: - # print(f"从上方看第 {j+1} 列可见楼数为 {visible_count},应为 {top[j]}") - return False - - # 检查从下方观察 - for j in range(n): - visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n-1, -1, -1)]) - if visible_count != bottom[j]: - # print(f"从下方看第 {j+1} 列可见楼数为 {visible_count},应为 {bottom[j]}") - return False - - # 检查从左侧观察 - for i in range(n): - visible_count = self._count_visible_skyscrapers(grid[i]) - if visible_count != left[i]: - # print(f"从左侧看第 {i+1} 行可见楼数为 {visible_count},应为 {left[i]}") + # print(f"验证: 游戏规模 {n}×{n}") + # print(f"上方提示: {top}") + # print(f"下方提示: {bottom}") + # print(f"左侧提示: {left}") + # print(f"右侧提示: {right}") + + # 使用提取好的网格数据 + grid = test_answer + + # 检查网格是否是字符串,如果是,说明提取失败 + if isinstance(grid, str): + # print("无法提取有效网格") return False - - # 检查从右侧观察 - for i in range(n): - visible_count = self._count_visible_skyscrapers(grid[i][::-1]) - if visible_count != right[i]: - # print(f"从右侧看第 {i+1} 行可见楼数为 {visible_count},应为 {right[i]}") + + # print("提取的网格:") + # for row in grid: + # print(row) + + # 检查网格规模 + if len(grid) != n or any(len(row) != n for row in grid): + # print(f"网格规模不正确,应为 {n}×{n}") return False - - # 所有检查通过 - # print("所有验证规则通过!") - return True + + # 检查数字范围 (1 到 n) + for i in range(n): + for j in range(n): + if not isinstance(grid[i][j], int) or grid[i][j] < 1 or grid[i][j] > n: + # print(f"位置 ({i+1},{j+1}) 的值 {grid[i][j]} 不在有效范围内 (1-{n})") + return False + + # 检查每行唯一性 + for i in range(n): + if len(set(grid[i])) != n: + # print(f"第 {i+1} 行包含重复数字") + return False + + # 检查每列唯一性 + for j in range(n): + column = [grid[i][j] for i in range(n)] + if len(set(column)) != n: + # print(f"第 {j+1} 列包含重复数字") + return False + + # 检查从上方观察 + for j in range(n): + visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n)]) + if visible_count != 
top[j]: + # print(f"从上方看第 {j+1} 列可见楼数为 {visible_count},应为 {top[j]}") + return False + + # 检查从下方观察 + for j in range(n): + visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n-1, -1, -1)]) + if visible_count != bottom[j]: + # print(f"从下方看第 {j+1} 列可见楼数为 {visible_count},应为 {bottom[j]}") + return False + + # 检查从左侧观察 + for i in range(n): + visible_count = self._count_visible_skyscrapers(grid[i]) + if visible_count != left[i]: + # print(f"从左侧看第 {i+1} 行可见楼数为 {visible_count},应为 {left[i]}") + return False + + # 检查从右侧观察 + for i in range(n): + visible_count = self._count_visible_skyscrapers(grid[i][::-1]) + if visible_count != right[i]: + # print(f"从右侧看第 {i+1} 行可见楼数为 {visible_count},应为 {right[i]}") + return False + + # 所有检查通过 + # print("所有验证规则通过!") + return True + return _verify_with_timeout() except Exception as e: print(f"Verification error (SkyscraperPuzzle): {e}") diff --git a/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py index abc165d5..bd41bc8d 100644 --- a/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py +++ b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py @@ -2,20 +2,28 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import math_verify +from verl.utils.py_functional import timeout_limit class SpaceReasoningTreeVerifier(Verifier): """ 验证器用于空间推理树游戏的答案是否正确 """ def verify(self, data: Data, test_answer: str): - test_answer = self.extract_answer(test_answer) - if test_answer is None: + try: + @timeout_limit(seconds=10) + def _verify_with_timeout(): + test_answer_extracted = self.extract_answer(test_answer) + if test_answer_extracted is None: + return False + test_answer_normalized = test_answer_extracted.replace(",", ",").replace(" ", "") + ground_truth = data.answer.replace(",", ",").replace(" ", "") + test_set = set(test_answer_normalized.split(",")) + ground_truth_set = 
set(ground_truth.split(",")) + return test_set == ground_truth_set + return _verify_with_timeout() + except Exception as e: + print(f"Verification error (SpaceReasoningTree): {e}") return False - test_answer = test_answer.replace(",", ",").replace(" ", "") - ground_truth = data.answer.replace(",", ",").replace(" ", "") - test_set = set(test_answer.split(",")) - ground_truth_set = set(ground_truth.split(",")) - return test_set == ground_truth_set def extract_answer(self, answer_str): # 先找到最后一个\boxed{的位置 diff --git a/verl/utils/reward_score/synlogic/space_reasoning_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py index 249f2dc0..ed51d2a9 100644 --- a/verl/utils/reward_score/synlogic/space_reasoning_verifier.py +++ b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py @@ -2,6 +2,7 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import math_verify +from verl.utils.py_functional import timeout_limit class SpaceReasoningVerifier(Verifier): @@ -9,10 +10,17 @@ class SpaceReasoningVerifier(Verifier): 验证器用于空间推理游戏的答案是否正确 """ def verify(self, data: Data, test_answer: str): - test_answer = self.extract_answer(test_answer) - if test_answer is None: + try: + @timeout_limit(seconds=10) + def _verify_with_timeout(): + test_answer_extracted = self.extract_answer(test_answer) + if test_answer_extracted is None: + return False + return test_answer_extracted.lower() == data.answer.lower() + return _verify_with_timeout() + except Exception as e: + print(f"Verification error (SpaceReasoning): {e}") return False - return test_answer.lower() == data.answer.lower() def extract_answer(self, answer_str): # 先找到最后一个\boxed{的位置 diff --git a/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py index 98715e19..15f2695e 100644 --- a/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py +++ 
b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py @@ -3,8 +3,7 @@ import re import json import ast - -import re +from verl.utils.py_functional import timeout_limit class StarPlacementPuzzleVerifier(Verifier): """ @@ -19,81 +18,84 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - star_coords = self.extract_answer(test_solution) - # 获取游戏元数据 - metadata = data.metadata - n = metadata['n'] - k = metadata['k'] - region_grid = metadata['region_grid'] - - # print(f"验证: 游戏规模 {n}×{n}, 每行/列/区域星星数量: {k}") - - # 检查是否有有效的星星坐标 - if not star_coords: - # print("无法从回答中提取有效的星星坐标") - return False - - # 创建一个表示星星位置的网格 - star_grid = [[0 for _ in range(n)] for _ in range(n)] - for region, coords in star_coords.items(): - for coord in coords: - row, col = coord - if row < 0 or row >= n or col < 0 or col >= n: - # print(f"无效坐标: ({row},{col}) - 超出网格范围") - return False - star_grid[row][col] = 1 - - # 打印星星网格以便调试 - # print("星星网格:") - for row in star_grid: - print(''.join(['* ' if cell == 1 else '. ' for cell in row])) - - # 1. 检查每行是否有k颗星星 - for i in range(n): - stars_in_row = sum(star_grid[i]) - if stars_in_row != k: - # print(f"行 {i+1} 有 {stars_in_row} 颗星星,应该有 {k} 颗") - return False - - # 2. 检查每列是否有k颗星星 - for j in range(n): - stars_in_col = sum(star_grid[i][j] for i in range(n)) - if stars_in_col != k: - # print(f"列 {j+1} 有 {stars_in_col} 颗星星,应该有 {k} 颗") - return False - - # 3. 
检查每个区域是否有k颗星星 - regions = {} - for i in range(n): - for j in range(n): - region = region_grid[i][j] - if region not in regions: - regions[region] = [] - regions[region].append((i, j)) - - for region, cells in regions.items(): - stars_in_region = sum(star_grid[i][j] for i, j in cells) - if stars_in_region != k: - # print(f"区域 {region} 有 {stars_in_region} 颗星星,应该有 {k} 颗") + @timeout_limit(seconds=10) + def _verify_with_timeout(): + star_coords = self.extract_answer(test_solution) + # 获取游戏元数据 + metadata = data.metadata + n = metadata['n'] + k = metadata['k'] + region_grid = metadata['region_grid'] + + # print(f"验证: 游戏规模 {n}×{n}, 每行/列/区域星星数量: {k}") + + # 检查是否有有效的星星坐标 + if not star_coords: + # print("无法从回答中提取有效的星星坐标") return False - - # 4. 检查星星是否互不相邻(水平、垂直、对角线) - for i in range(n): + + # 创建一个表示星星位置的网格 + star_grid = [[0 for _ in range(n)] for _ in range(n)] + for region, coords in star_coords.items(): + for coord in coords: + row, col = coord + if row < 0 or row >= n or col < 0 or col >= n: + # print(f"无效坐标: ({row},{col}) - 超出网格范围") + return False + star_grid[row][col] = 1 + + # 打印星星网格以便调试 + # print("星星网格:") + # for row in star_grid: + # print(''.join(['* ' if cell == 1 else '. ' for cell in row])) + + # 1. 检查每行是否有k颗星星 + for i in range(n): + stars_in_row = sum(star_grid[i]) + if stars_in_row != k: + # print(f"行 {i+1} 有 {stars_in_row} 颗星星,应该有 {k} 颗") + return False + + # 2. 检查每列是否有k颗星星 for j in range(n): - if star_grid[i][j] == 1: - # 检查周围8个方向 - for di in [-1, 0, 1]: - for dj in [-1, 0, 1]: - if di == 0 and dj == 0: - continue # 跳过自身 - ni, nj = i + di, j + dj - if 0 <= ni < n and 0 <= nj < n and star_grid[ni][nj] == 1: - # print(f"星星在 ({i},{j}) 与星星在 ({ni},{nj}) 相邻") - return False - - # 所有检查通过 - # print("所有验证规则通过!") - return True + stars_in_col = sum(star_grid[i][j] for i in range(n)) + if stars_in_col != k: + # print(f"列 {j+1} 有 {stars_in_col} 颗星星,应该有 {k} 颗") + return False + + # 3. 
检查每个区域是否有k颗星星 + regions = {} + for i in range(n): + for j in range(n): + region = region_grid[i][j] + if region not in regions: + regions[region] = [] + regions[region].append((i, j)) + + for region, cells in regions.items(): + stars_in_region = sum(star_grid[i][j] for i, j in cells) + if stars_in_region != k: + # print(f"区域 {region} 有 {stars_in_region} 颗星星,应该有 {k} 颗") + return False + + # 4. 检查星星是否互不相邻(水平、垂直、对角线) + for i in range(n): + for j in range(n): + if star_grid[i][j] == 1: + # 检查周围8个方向 + for di in [-1, 0, 1]: + for dj in [-1, 0, 1]: + if di == 0 and dj == 0: + continue # 跳过自身 + ni, nj = i + di, j + dj + if 0 <= ni < n and 0 <= nj < n and star_grid[ni][nj] == 1: + # print(f"星星在 ({i},{j}) 与星星在 ({ni},{nj}) 相邻") + return False + + # 所有检查通过 + # print("所有验证规则通过!") + return True + return _verify_with_timeout() except Exception as e: print(f"Verification error (StarPlacementPuzzle): {e}") diff --git a/verl/utils/reward_score/synlogic/time_sequence_verifier.py b/verl/utils/reward_score/synlogic/time_sequence_verifier.py index 711a43d8..6133f437 100644 --- a/verl/utils/reward_score/synlogic/time_sequence_verifier.py +++ b/verl/utils/reward_score/synlogic/time_sequence_verifier.py @@ -3,6 +3,7 @@ from .data import Data from .verifier import Verifier import re +from verl.utils.py_functional import timeout_limit class TimeSequenceVerifier(Verifier): """ @@ -17,35 +18,39 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - test_answer = self.extract_answer(test_solution) - # 解析元数据 - metadata = data.metadata - true_answers = metadata['records']['answers'] - - # 解析模型给出的列表 - try: - test_list = json.loads(test_answer.replace(",", ",")) - except: - print(f"NOTE!!! parse error!!!! 
(TimeSequence 1): {e}") - return False - - try: - if test_list[0]!=true_answers['answer_maxLen']: - # print(f"最长会议时间不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + @timeout_limit(seconds=10) + def _verify_with_timeout(): + test_answer = self.extract_answer(test_solution) + # 解析元数据 + metadata = data.metadata + true_answers = metadata["records"]["answers"] + + # 解析模型给出的列表 + try: + test_list = json.loads(test_answer.replace(",", ",")) + except Exception as e: + # print(f"NOTE!!! parse error!!!! (TimeSequence 1): {e}") return False - if test_list[1]!=true_answers['answer_nums']: - # print(f"可选会议数量不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + + try: + if test_list[0] != true_answers["answer_maxLen"]: + # print(f"最长会议时间不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + return False + if test_list[1] != true_answers["answer_nums"]: + # print(f"可选会议数量不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + return False + except Exception as e: + # print(f"NOTE!!! parse error!!!! (TimeSequence 2): {e}") return False - except: - print(f"NOTE!!! parse error!!!! 
(TimeSequence 2): {e}") - return False - - # 所有检查都通过 - # print("验证结果: 正确") - return True + + # 所有检查都通过 + # print("验证结果: 正确") + return True + + return _verify_with_timeout() except Exception as e: print(f"Verification error (TimeSequence): {e}") - return False + return False def extract_answer(self, test_solution: str): """ diff --git a/verl/utils/reward_score/synlogic/web_of_lies_verifier.py b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py index 94be44fe..751ea9c3 100644 --- a/verl/utils/reward_score/synlogic/web_of_lies_verifier.py +++ b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py @@ -1,6 +1,7 @@ import re from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +from verl.utils.py_functional import timeout_limit class WebOfLiesVerifier(Verifier): """ @@ -15,35 +16,37 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - test_answer = self.extract_answer(test_solution) - # 获取预期答案和测试答案 - expected_answer = data.answer.lower() - - # 清理测试答案 - test_answer = test_answer.lower() - - # 提取预期答案中的真假值 - expected_truths = self._parse_answer(expected_answer) - - # 提取测试答案中的真假值 - test_truths = self._parse_answer(test_answer) - - # print(f"验证: 预期答案={expected_truths}, 模型答案={test_truths}") - - # 检查答案列表长度是否匹配 - if len(expected_truths) != len(test_truths): - # print(f"验证失败: 答案长度不匹配,预期 {len(expected_truths)},实际 {len(test_truths)}") - return False - - # 检查每个位置的答案是否匹配 - for i, (expected, actual) in enumerate(zip(expected_truths, test_truths)): - if expected != actual: - # print(f"验证失败: 第 {i+1} 个答案不匹配,预期 {expected},实际 {actual}") + @timeout_limit(seconds=10) + def _verify_with_timeout(): + test_answer = self.extract_answer(test_solution) + # 获取预期答案和测试答案 + expected_answer = data.answer.lower() + + # 清理测试答案 + test_answer = test_answer.lower() + + # 提取预期答案中的真假值 + expected_truths = self._parse_answer(expected_answer) + + # 提取测试答案中的真假值 + test_truths = self._parse_answer(test_answer) + + # 
print(f"验证: 预期答案={expected_truths}, 模型答案={test_truths}") + + # 检查答案列表长度是否匹配 + if len(expected_truths) != len(test_truths): + # print(f"验证失败: 答案长度不匹配,预期 {len(expected_truths)},实际 {len(test_truths)}") return False - - # print("验证成功: 所有答案匹配") - return True - + + # 检查每个位置的答案是否匹配 + for i, (expected, actual) in enumerate(zip(expected_truths, test_truths)): + if expected != actual: + # print(f"验证失败: 第 {i+1} 个答案不匹配,预期 {expected},实际 {actual}") + return False + + # print("验证成功: 所有答案匹配") + return True + return _verify_with_timeout() except Exception as e: print(f"Verification error (WebOfLies): {e}") return False diff --git a/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py index 8216f760..3f2366b2 100644 --- a/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py +++ b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py @@ -1,6 +1,7 @@ import re from .data import Data from .verifier import Verifier +from verl.utils.py_functional import timeout_limit class WordSortingMistakeVerifier(Verifier): """ @@ -8,19 +9,22 @@ class WordSortingMistakeVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - ground_truth = data.answer if data.answer is not None else "No" - parsed_answer = self.extract_answer(test_answer) - - if parsed_answer is None: - return False - - if parsed_answer.isdigit(): - try: - return int(parsed_answer) == int(ground_truth) - except Exception as e: + @timeout_limit(seconds=10) + def _verify_with_timeout(): + ground_truth = data.answer if data.answer is not None else "No" + parsed_answer = self.extract_answer(test_answer) + + if parsed_answer is None: return False - else: - return parsed_answer.lower() == ground_truth.lower() + + if parsed_answer.isdigit(): + try: + return int(parsed_answer) == int(ground_truth) + except Exception as e: + return False + else: + return parsed_answer.lower() == str(ground_truth).lower() + return 
_verify_with_timeout() except Exception as e: print(f"NOTE!!! parse error!!!! (WordSortingMistake): {e}") return False diff --git a/verl/utils/reward_score/synlogic/word_sorting_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_verifier.py index 56758108..df3d6cf2 100644 --- a/verl/utils/reward_score/synlogic/word_sorting_verifier.py +++ b/verl/utils/reward_score/synlogic/word_sorting_verifier.py @@ -1,6 +1,7 @@ import re from .data import Data from .verifier import Verifier +from verl.utils.py_functional import timeout_limit class WordSortingVerifier(Verifier): """ @@ -13,12 +14,15 @@ def str2list(self, answer_str): def verify(self, data: Data, test_answer: str): try: - ground_truth = self.str2list(data.answer) - parsed_answer = self.str2list(self.extract_answer(test_answer)) - - if parsed_answer is None: - return False - return parsed_answer == ground_truth + @timeout_limit(seconds=10) + def _verify_with_timeout(): + ground_truth = self.str2list(data.answer) + parsed_answer = self.str2list(self.extract_answer(test_answer)) + + if parsed_answer is None: + return False + return parsed_answer == ground_truth + return _verify_with_timeout() except Exception as e: print(f"NOTE!!! parse error!!!! 
(WordSorting): {e}") diff --git a/verl/utils/reward_score/synlogic/wordscapes_verifier.py b/verl/utils/reward_score/synlogic/wordscapes_verifier.py index 2f30a784..796f06c0 100644 --- a/verl/utils/reward_score/synlogic/wordscapes_verifier.py +++ b/verl/utils/reward_score/synlogic/wordscapes_verifier.py @@ -5,155 +5,173 @@ import json import re from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +from verl.utils.py_functional import timeout_limit debug_mode = False + class WordscapesVerifier(Verifier): """ Verifier for Wordscapes game """ + def verify(self, data, test_solution: str): """ Verify whether the test answer is consistent with the gold answer - + Args: data: WordscapesData test_solution: str containing the solution - + Returns: float: Score between 0 and 1 """ try: - extracted_answer = self.extract_answer(test_solution) - if not extracted_answer: - print("NOTE!!! parse error!!!! (Wordscapes): {e}") - return False - - if debug_mode: - for row in extracted_answer: - print(" ".join(cell if cell != " " else "_" for cell in row)) - - # Get grid, across_words, and down_words from data - grid = data.metadata["grid"] - across_words = data.metadata["across_words"] - down_words = data.metadata["down_words"] - - # Validate grid dimensions - if len(extracted_answer) != len(grid): - # print(f"Grid height mismatch: expected {len(grid)}, got {len(extracted_answer)}") - return False - - for i in range(len(grid)): - if len(extracted_answer[i]) != len(grid[i]): - # print(f"Grid width mismatch at row {i}: expected {len(grid[i])}, got {len(extracted_answer[i])}") + + @timeout_limit(seconds=10) + def _verify_with_timeout(): + extracted_answer = self.extract_answer(test_solution) + if not extracted_answer: + # print("NOTE!!! parse error!!!! 
(Wordscapes): {e}") return False - - # Check if the answer respects the grid layout (X for letters, 0 for empty) - for i in range(len(grid)): - for j in range(len(grid[i])): - if grid[i][j] == "0" and extracted_answer[i][j].strip(): - # print(f"Expected empty space at position ({i},{j}), got '{extracted_answer[i][j]}'") - return False - if grid[i][j] == "X" and not extracted_answer[i][j].strip(): - # print(f"Expected letter at position ({i},{j}), got empty space") + + if debug_mode: + for row in extracted_answer: + print(" ".join(cell if cell != " " else "_" for cell in row)) + + # Get grid, across_words, and down_words from data + grid = data.metadata["grid"] + across_words = data.metadata["across_words"] + down_words = data.metadata["down_words"] + + # Validate grid dimensions + if len(extracted_answer) != len(grid): + # print(f"Grid height mismatch: expected {len(grid)}, got {len(extracted_answer)}") + return False + + for i in range(len(grid)): + if len(extracted_answer[i]) != len(grid[i]): + # print(f"Grid width mismatch at row {i}: expected {len(grid[i])}, got {len(extracted_answer[i])}") return False - - # Verify across words - for word in across_words: - found = False - for i in range(len(extracted_answer)): - row_str = ''.join(extracted_answer[i]).replace(' ', '').lower() - if word.lower() in row_str: - found = True - break - if not found and word: - # print(f"Across word '{word}' not found in the grid") - return 0 - - # Verify down words - for word in down_words: - found = False - for j in range(len(extracted_answer[0])): - col = [] + + # Check if the answer respects the grid layout (X for letters, 0 for empty) + for i in range(len(grid)): + for j in range(len(grid[i])): + if grid[i][j] == "0" and extracted_answer[i][j].strip(): + # print(f"Expected empty space at position ({i},{j}), got '{extracted_answer[i][j]}'") + return False + if grid[i][j] == "X" and not extracted_answer[i][j].strip(): + # print(f"Expected letter at position ({i},{j}), got empty 
space") + return False + + # Verify across words + for word_info in across_words: + found = False + target_word = word_info["word"] if isinstance(word_info, dict) else word_info for i in range(len(extracted_answer)): - if j < len(extracted_answer[i]): - col.append(extracted_answer[i][j]) - col_str = ''.join(col).replace(' ', '').lower() - if word.lower() in col_str: - found = True - break - if not found and word: # Only check if word is not empty - # print(f"Down word '{word}' not found in the grid") - return False - - # All checks passed - return True + row_str = "".join(extracted_answer[i]).replace(" ", "").lower() + if target_word.lower() in row_str: + found = True + break + if not found: + # print(f"Across word '{target_word}' not found in the grid") + return False + + # Verify down words + for word_info in down_words: + found = False + target_word = word_info["word"] if isinstance(word_info, dict) else word_info + for j in range(len(extracted_answer[0])): + col = [] + for i in range(len(extracted_answer)): + if j < len(extracted_answer[i]): + col.append(extracted_answer[i][j]) + col_str = "".join(col).replace(" ", "").lower() + if target_word.lower() in col_str: + found = True + break + if not found: + # print(f"Down word '{target_word}' not found in the grid") + return False + + # All checks passed + return True + + return _verify_with_timeout() except Exception as e: print(f"Verification error (Wordscapes): {e}") return False - + def extract_answer(self, test_solution: str): """ Extract the answer from the test solution - + Args: test_solution: str - + Returns: list: 2D grid of the answer or None if extraction fails """ try: # Remove thoughts if present - if THOUGHT_DELIMITER_START in test_solution and THOUGHT_DELIMITER_END in test_solution: + if ( + THOUGHT_DELIMITER_START in test_solution + and THOUGHT_DELIMITER_END in test_solution + ): # Extract only the part after the thoughts thought_end_pos = test_solution.rfind(THOUGHT_DELIMITER_END) if 
thought_end_pos >= 0: - test_solution = test_solution[thought_end_pos + len(THOUGHT_DELIMITER_END):] - + test_solution = test_solution[ + thought_end_pos + len(THOUGHT_DELIMITER_END) : + ] + # Clean up the response and find the grid pattern # Look for a pattern like [[...]] or [[[...]]] - grid_pattern = re.search(r'\[\s*\[(?:\s*\[)?(.+?)(?:\]\s*)?\]\s*\]', test_solution, re.DOTALL) + grid_pattern = re.search( + r"\[\s*\[(?:\s*\[)?(.+?)(?:\]\s*)?\]\s*\]", test_solution, re.DOTALL + ) if not grid_pattern: return None - + grid_text = grid_pattern.group(1) - + # Handle various formats rows = [] - + # Check if rows are separated by commas - split_rows = re.split(r'\],\s*\[', grid_text) - + split_rows = re.split(r"\],\s*\[", grid_text) + for row_text in split_rows: # Clean the row text and extract characters - row_text = row_text.strip().strip('[],') - + row_text = row_text.strip().strip("[],") + # Extract quoted characters: "X" or 'X' or just X chars = [] - + # Look for quoted strings or standalone characters - char_matches = re.findall(r'\"([^\"]*)\"|\'([^\']*)\'|([^,\s]+)', row_text) - + char_matches = re.findall( + r"\"([^\"]*)\"|\'([^\']*)\'|([^,\s]+)", row_text + ) + for match in char_matches: # Take the first non-empty group from each match char = next((x for x in match if x), "") - + # Handle numeric or empty values (0, "", '') if char == "0" or char == "": char = " " - + chars.append(char) - + if chars: # Only add non-empty rows rows.append(chars) - + # Make sure we have a valid grid if not rows or not all(rows): return None - + return rows - + except Exception as e: print(f"NOTE!!! parse error!!!! 
(Wordscapes): {e}") return None - \ No newline at end of file From dc4300cc8973f736e917dc9cc4750d41b8c0e70e Mon Sep 17 00:00:00 2001 From: Varad Pimpalkhute Date: Sun, 11 Jan 2026 20:23:49 +0000 Subject: [PATCH 16/20] fix reward functions --- scripts/train/k2p_hero_grpo_newData.sh | 14 +- scripts/train/k2p_hero_grpo_newData_7B.sh | 362 ++++++++++++++++++ verl/utils/py_functional.py | 11 + .../reward_score/math_llm_judge/__init__.py | 10 +- .../reward_score/math_llm_judge/grader.py | 21 +- verl/utils/reward_score/naive_dapo.py | 15 +- .../utils/reward_score/prime_math/__init__.py | 2 +- verl/workers/reward_manager/async_mp.py | 5 +- verl/workers/reward_manager/dapo.py | 1 + 9 files changed, 403 insertions(+), 38 deletions(-) create mode 100644 scripts/train/k2p_hero_grpo_newData_7B.sh diff --git a/scripts/train/k2p_hero_grpo_newData.sh b/scripts/train/k2p_hero_grpo_newData.sh index fcee39a9..a387f442 100644 --- a/scripts/train/k2p_hero_grpo_newData.sh +++ b/scripts/train/k2p_hero_grpo_newData.sh @@ -15,8 +15,10 @@ # =================== Frequently Used Variables =================== RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-easy50k-7domains-415354" # Fill in the checkpoint directory name to resume from, otherwise from scratch -export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain -export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty +# export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +# export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-227:8000" # Fill in the llm-as-judge hosted URL, currently 
used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-291:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty # =================== Cluster Environment =================== export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ @@ -172,6 +174,7 @@ test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${a # =================== Model =================== # BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT # BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface # BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface @@ -287,7 +290,7 @@ offload=True actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ actor_rollout_ref.actor.clip_ratio_c=10.0 \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=48000 \ actor_rollout_ref.actor.strategy="fsdp2" \ actor_rollout_ref.actor.optim.lr=5e-7 \ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ @@ -346,17 +349,16 @@ offload=True reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.num_processes=64 \ 
trainer.logger=['console','wandb'] \ trainer.project_name=${WANDB_PROJECT} \ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ trainer.val_before_train=False \ trainer.n_gpus_per_node=8 \ trainer.nnodes=$worker_num \ - trainer.save_freq=1 \ + trainer.save_freq=5 \ trainer.test_freq=5 \ trainer.total_epochs=5 \ - trainer.log_val_generations=0 \ trainer.resume_mode=auto \ trainer.max_actor_ckpt_to_keep=3 - # data.id_val_files="$id_val_files" \ # trainer.log_val_generations=50 \ \ No newline at end of file diff --git a/scripts/train/k2p_hero_grpo_newData_7B.sh b/scripts/train/k2p_hero_grpo_newData_7B.sh new file mode 100644 index 00000000..694fab7d --- /dev/null +++ b/scripts/train/k2p_hero_grpo_newData_7B.sh @@ -0,0 +1,362 @@ +#!/bin/bash +#SBATCH --job-name=grpo-stage2-k2pRL-easy50k-7domains +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=higherprio +# commenting out... 
SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-easy50k-7domains-415354" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 \ +NCCL_NVLS_ENABLE=0 + + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== + +# Training Data Configuration 
+DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1" +train_file_list=() +id_val_file_list=() + +iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet" + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + "codegen__deduped_leetcode2k_2.4k.parquet" + "codegen__deduped_livecodebench_599.parquet" + "codegen__deduped_primeintellect_9.6k.parquet" + "codegen__deduped_taco_11.1k.parquet" + "ifbench__fixed_85.6k.parquet" + "math__combined_118.2k.part1.parquet" + "math__combined_118.2k.part2.parquet" + "omni_math_4.43k_dedup.parquet" + "stem__nemotron_13.3k.parquet" + "stem__web_31.7k.parquet" + "table__hitab_7.4k.parquet" + "table__multihier_2.9k.parquet" +) +# "simulation__codeio_fixed_12.1k.parquet" +# "logic__arcagi1_297.parquet" +# "logic__arcagi2_653.parquet" +# "logic__barc_3.4k.parquet" +# "logic__graph_logical_dataset_1.4k.parquet" +# "logic__ordering_puzzle_dataset_2.9k.parquet" +# "logic__reasoning_gym_40.6k.parquet" +# "logic__synlogic_12.1k.parquet" +# "logic__zebra_puzzle_dataset_5.0k.parquet" + +echo "Collecting training files from ${DATA_MIX_DIR}..." 
+ +# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions" +for dataset in "${dataset_names[@]}"; do + for subdir in "main_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +# for dataset in "${dataset_names[@]}"; do +# for subdir in "131k_context_questions"; do +# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" +# if [ -f "$file_path" ]; then +# echo "Adding: $file_path" +# id_val_file_list+=("'$file_path'") +# fi +# done +# done +# id_val_file_list+=("'$iq400_path'") + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +# id_val_files="[${id_val_file_list[*]}]" +unset IFS + +echo "Total training files found: ${#train_file_list[@]}" +# echo "Total ID validation files found: ${#id_val_file_list[@]}" + +# Test Data Configuration +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# 
ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (test) +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# 
BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} +# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. 
+ +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 64)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.2 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Reduced from 32 to reduce memory pressure +gen_tp=4 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + 
data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=${offload} \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=1 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=0 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 + # data.id_val_files="$id_val_files" \ + # trainer.log_val_generations=50 \ \ No newline at end of file diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py index 159c2589..b09ec3dd 100644 --- a/verl/utils/py_functional.py +++ b/verl/utils/py_functional.py @@ -111,8 +111,19 @@ def wrapper_mp(*args, **kwargs): if process.is_alive(): process.terminate() process.join(timeout=0.5) # Give it a moment to 
terminate + if process.is_alive(): + try: + process.kill() + print(f"Warning: Escalated to force kill process {process.pid}") + except AttributeError: + os.kill(process.pid, signal.SIGKILL) + print(f"Fall back to very old way of killing process") + process.join(timeout=0.5) + if process.is_alive(): print(f"Warning: Process {process.pid} did not terminate gracefully after timeout.") + raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds and could not be killed (pid={process.pid})!") + # Update function name in error message if needed (optional but good practice) raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!") diff --git a/verl/utils/reward_score/math_llm_judge/__init__.py b/verl/utils/reward_score/math_llm_judge/__init__.py index e0fba8d5..f725986f 100644 --- a/verl/utils/reward_score/math_llm_judge/__init__.py +++ b/verl/utils/reward_score/math_llm_judge/__init__.py @@ -40,7 +40,7 @@ import math import sympy -from pylatexenc import latex2text +# from pylatexenc import latex2text from sympy.parsing import sympy_parser import requests from verl.utils.py_functional import timeout_limit @@ -48,6 +48,9 @@ from . import math_normalize from .grader import math_equal +import requests +from verl.utils.py_functional import timeout_limit + # import math_normalize # from grader import math_equal @@ -57,8 +60,6 @@ TUPLE_CHARS = "()[]" - - def _sympy_parse(expr: str): """Parses an expression with sympy.""" py_expr = expr.replace("^", "**") @@ -73,7 +74,7 @@ def _parse_latex(expr: str) -> str: expr = expr.replace("\\tfrac", "\\frac") expr = expr.replace("\\dfrac", "\\frac") expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. - expr = latex2text.LatexNodes2Text().latex_to_text(expr) + # expr = latex2text.LatexNodes2Text().latex_to_text(expr) # Replace the specific characters that this parser uses. 
expr = expr.replace("√", "sqrt") @@ -377,6 +378,7 @@ def llm_check_answer(model_output: str, ground_truth: str, question: str) -> boo # use llm to check if the answer is correct # url = "http://176.56.200.81:30000/v1/chat/completions" + import os url_base = os.getenv("MATH_LLM_JUDGE_URL") if not url_base: raise ValueError("MATH_LLM_JUDGE_URL is not set") diff --git a/verl/utils/reward_score/math_llm_judge/grader.py b/verl/utils/reward_score/math_llm_judge/grader.py index 87d8bac6..5eaa2267 100644 --- a/verl/utils/reward_score/math_llm_judge/grader.py +++ b/verl/utils/reward_score/math_llm_judge/grader.py @@ -92,7 +92,6 @@ - https://github.com/openai/prm800k """ -import contextlib import re import math from math import isclose @@ -304,18 +303,16 @@ def math_equal(prediction: Union[bool, float, str], except Exception: pass - return symbolic_equal(prediction, reference, tolerance, timeout) + return symbolic_equal(prediction, reference, tolerance) -def symbolic_equal(a, b, tolerance, timeout=10.0): +@timeout_limit(seconds=10) +def symbolic_equal(a, b, tolerance): def _parse(s): for f in [parse_expr, parse_latex]: try: - @timeout_limit(seconds=timeout) - def _parse_with_timeout(): - return f(s) - return _parse_with_timeout() + return f(s) except Exception: pass return s @@ -324,19 +321,13 @@ def _parse_with_timeout(): b = _parse(b) try: - @timeout_limit(seconds=timeout) - def _simplify_with_timeout(): - return simplify(a - b) == 0 - if _simplify_with_timeout(): + if simplify(a - b) == 0: return True except Exception: pass try: - @timeout_limit(seconds=timeout) - def _numeric_equal_with_timeout(): - return isclose(N(a), N(b), rel_tol=tolerance) - if _numeric_equal_with_timeout(): + if isclose(N(a), N(b), rel_tol=tolerance): return True except Exception: pass diff --git a/verl/utils/reward_score/naive_dapo.py b/verl/utils/reward_score/naive_dapo.py index 8ad39d01..739b9c9e 100644 --- a/verl/utils/reward_score/naive_dapo.py +++ b/verl/utils/reward_score/naive_dapo.py @@ 
-18,12 +18,13 @@ import math import sympy -from pylatexenc import latex2text +# from pylatexenc import latex2text from sympy.parsing import sympy_parser from verl.utils.py_functional import timeout_limit from .prime_math import math_normalize from .prime_math.grader import math_equal +from verl.utils.py_functional import timeout_limit # Constants for normalization @@ -393,10 +394,8 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) is_correct = False else: - try: - is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) - except TimeoutError: - is_correct = False + is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + # is_correct = False if not is_correct: break @@ -477,11 +476,7 @@ def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> flo if "\\pi" in extracted_model_output or "\\pi" in ground_truth: equivs = [] for pi in [math.pi, 3.14]: - equivs.append( - math_equal( - extracted_model_output, ground_truth, timeout=True, pi=pi - ) - ) + equivs.append(math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)) correct = any(equivs) else: correct = math_equal(extracted_model_output, ground_truth, timeout=True) diff --git a/verl/utils/reward_score/prime_math/__init__.py b/verl/utils/reward_score/prime_math/__init__.py index 8d9d273e..50b02b3f 100644 --- a/verl/utils/reward_score/prime_math/__init__.py +++ b/verl/utils/reward_score/prime_math/__init__.py @@ -24,7 +24,7 @@ import re import sympy -from pylatexenc import latex2text +# from pylatexenc import latex2text from sympy.parsing import sympy_parser from verl.utils.py_functional import timeout_limit diff --git a/verl/workers/reward_manager/async_mp.py b/verl/workers/reward_manager/async_mp.py index 09b913c7..7150b66c 100644 --- a/verl/workers/reward_manager/async_mp.py +++ b/verl/workers/reward_manager/async_mp.py @@ 
-149,6 +149,7 @@ def __init__( overlong_buffer_cfg=None, batch_size=2048, shuffle_batch=True, + num_processes=32, **kwargs, ) -> None: self.tokenizer = tokenizer @@ -159,7 +160,7 @@ def __init__( self.max_resp_len = max_resp_len self.batch_size = batch_size self.shuffle_batch = shuffle_batch - + self.num_processes = num_processes if self.overlong_buffer_cfg is not None: assert self.max_resp_len is not None, ( f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" @@ -264,7 +265,7 @@ def __call__(self, data: DataProto, return_dict: bool = False): solutions, ground_truths, extra_infos, - num_processes=64, + num_processes=self.num_processes, batch_size=self.batch_size, shuffle=self.shuffle_batch, ) diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py index d8b6b474..9d3aa6b0 100644 --- a/verl/workers/reward_manager/dapo.py +++ b/verl/workers/reward_manager/dapo.py @@ -109,6 +109,7 @@ def __call__(self, data: DataProto, return_dict: bool = False): reward_extra_info[key].append(value) else: score = result + reward_extra_info["score"].append(score) reward_extra_info["acc"].append(score) reward = score From 8e671a082454271c53353e930618c9afefc0d09a Mon Sep 17 00:00:00 2001 From: Varad Pimpalkhute Date: Sun, 11 Jan 2026 20:24:59 +0000 Subject: [PATCH 17/20] delete 7B script --- scripts/train/k2p_hero_grpo_newData_7B.sh | 362 ---------------------- 1 file changed, 362 deletions(-) delete mode 100644 scripts/train/k2p_hero_grpo_newData_7B.sh diff --git a/scripts/train/k2p_hero_grpo_newData_7B.sh b/scripts/train/k2p_hero_grpo_newData_7B.sh deleted file mode 100644 index 694fab7d..00000000 --- a/scripts/train/k2p_hero_grpo_newData_7B.sh +++ /dev/null @@ -1,362 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=grpo-stage2-k2pRL-easy50k-7domains -#SBATCH --nodes=64 -#SBATCH --ntasks=64 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:8 -#SBATCH --cpus-per-task=96 -#SBATCH --mem=0 -#SBATCH --output=slurm/%x-%j.log -#SBATCH 
--error=slurm/%x-%j.log -#SBATCH --exclusive -#SBATCH --time=720:00:00 -#SBATCH --partition=higherprio -# commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] - -# =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-easy50k-7domains-415354" # Fill in the checkpoint directory name to resume from, otherwise from scratch -export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain -export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty - -# =================== Cluster Environment =================== -export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ -export ROCR_VISIBLE_DEVICES=None -export NCCL_TIMEOUT_SECONDS=4800000 -export OMPI_MCA_coll_hcoll_enable=0 \ -TORCH_NCCL_ENABLE_MONITORING=0 \ -CUDA_DEVICE_ORDER=PCI_BUS_ID \ -NCCL_SOCKET_IFNAME=eth0 \ -UCX_TLS=rc \ -UCX_NET_DEVICES=mlx5_ib0:1 \ -NCCL_DEBUG=WARN \ -NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ -NCCL_IB_PCI_RELAXED_ORDERING=1 \ -NCCL_IB_QPS_PER_CONNECTION=4 \ -NCCL_IGNORE_CPU_AFFINITY=1 \ -NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ -NCCL_PXN_DISABLE=1 \ -NCCL_MIN_NCHANNELS=32 \ -SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ -SHARP_COLL_ENABLE_SAT=1 \ -SHARP_COLL_LOG_LEVEL=3 \ -SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ -NCCL_COLLNET_ENABLE=1 \ -NCCL_NVLS_ENABLE=0 - - -# Get the list of allocated nodes -nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) -echo "Nodes to check: ${nodes[@]}" - -# We'll track PIDs so we can wait on them and detect errors -declare -A pids -export head_node=${nodes[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) -port=6379 -address_head=$head_node_ip:$port - -export worker_num=$SLURM_NNODES -export HYDRA_FULL_ERROR=1 -export 
VLLM_USE_V1=1 - -# =================== Data Mixture =================== - -# Training Data Configuration -DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1" -train_file_list=() -id_val_file_list=() - -iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet" - -# List of datasets to include (filename only) -# Comment out lines to exclude specific datasets -dataset_names=( - "codegen__deduped_leetcode2k_2.4k.parquet" - "codegen__deduped_livecodebench_599.parquet" - "codegen__deduped_primeintellect_9.6k.parquet" - "codegen__deduped_taco_11.1k.parquet" - "ifbench__fixed_85.6k.parquet" - "math__combined_118.2k.part1.parquet" - "math__combined_118.2k.part2.parquet" - "omni_math_4.43k_dedup.parquet" - "stem__nemotron_13.3k.parquet" - "stem__web_31.7k.parquet" - "table__hitab_7.4k.parquet" - "table__multihier_2.9k.parquet" -) -# "simulation__codeio_fixed_12.1k.parquet" -# "logic__arcagi1_297.parquet" -# "logic__arcagi2_653.parquet" -# "logic__barc_3.4k.parquet" -# "logic__graph_logical_dataset_1.4k.parquet" -# "logic__ordering_puzzle_dataset_2.9k.parquet" -# "logic__reasoning_gym_40.6k.parquet" -# "logic__synlogic_12.1k.parquet" -# "logic__zebra_puzzle_dataset_5.0k.parquet" - -echo "Collecting training files from ${DATA_MIX_DIR}..." 
- -# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions" -for dataset in "${dataset_names[@]}"; do - for subdir in "main_questions"; do - file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" - if [ -f "$file_path" ]; then - echo "Adding: $file_path" - train_file_list+=("'$file_path'") - fi - done -done - -# for dataset in "${dataset_names[@]}"; do -# for subdir in "131k_context_questions"; do -# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" -# if [ -f "$file_path" ]; then -# echo "Adding: $file_path" -# id_val_file_list+=("'$file_path'") -# fi -# done -# done -# id_val_file_list+=("'$iq400_path'") - -# Join with comma to form Python list string -IFS=, -train_files="[${train_file_list[*]}]" -# id_val_files="[${id_val_file_list[*]}]" -unset IFS - -echo "Total training files found: ${#train_file_list[@]}" -# echo "Total ID validation files found: ${#id_val_file_list[@]}" - -# Test Data Configuration -TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len -# Math (test) -math_test_path=${TEST_DATA_DIR}/math__math_500.parquet -aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet -aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet -amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet - -# Code (test) -humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet -mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet -livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet - -# Logic (test) -zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet -reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet -synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet -arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet -# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet -# 
ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet - -# Table (test) -multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet -hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet - -# Stem (test) -nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet -gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet -supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet - -# Instruction follow (test) -if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet -if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet - -# Focused data mixture (math, code, stem) -# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" -# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" - -# Full data mixture (uncomment to use) -test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', - - -# =================== Model =================== -# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT -# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) -# 
BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k -BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface - -# =================== Logging =================== -WANDB_PROJECT=k2plus_rl -WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} -# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491" - -# If RESUME_CKPT_DIR is not empty, resume from the checkpoint -if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then - WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" -fi - - -# =================== Ray start =================== -# ray stop at all nodes -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop - -sleep 10 -# Remove existing Ray cluster -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster - -# Start Ray head node -srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ - env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ - ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & - -sleep 10 - -# Start Ray worker nodes -for ((i = 1; i < worker_num; i++)); do - node_i=${nodes[$i]} - echo "Starting WORKER $i at $node_i" - srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ - env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ - ${CONDA_BIN_PATH}ray start --address "$address_head" \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & -done -sleep 10 - - -# =================== RL Config =================== -# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. 
- -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 4)) -max_response_length=$((1024 * 64)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 12)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" -rollout_dtype="float16" - -enable_filter_groups=False -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n -gen_prompt_bsz=$((train_prompt_bsz * 1)) -n_resp_per_prompt=16 -train_prompt_mini_bsz=256 # model grad update batchsize - -# Algorithm -temperature=1.2 -val_temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout - -# Training config -sp_size=16 # Reduced from 32 to reduce memory pressure -gen_tp=4 -gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure -infer_micro_batch_size=null -train_micro_batch_size=null -use_dynamic_bsz=True -actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow -infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow -offload=True - -# =================== Start RL training =================== -"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ - --config-path=config \ - --config-name="dapo_fsdp_config.yaml" \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.prompt_key=prompt \ - data.truncation='right' \ - data.max_prompt_length=${max_prompt_length} \ - 
data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - data.gen_batch_size=${gen_prompt_bsz} \ - actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ - actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.actor.strategy="fsdp2" \ - actor_rollout_ref.actor.optim.lr=5e-7 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.optim.warmup_style=constant \ - actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ - actor_rollout_ref.actor.entropy_checkpointing=True \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ - actor_rollout_ref.rollout.disable_log_stats=False \ - actor_rollout_ref.rollout.enforce_eager=False \ - actor_rollout_ref.rollout.enable_prefix_caching=True \ - actor_rollout_ref.rollout.temperature=${temperature} \ - 
actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ - actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.model.path=$BASE_MODEL \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.rollout.multi_turn.enable=False \ - actor_rollout_ref.rollout.mode="sync" \ - +actor_rollout_ref.model.override_config.attention_dropout=0. \ - +actor_rollout_ref.model.override_config.embd_pdrop=0. \ - +actor_rollout_ref.model.override_config.resid_pdrop=0. \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.enable_activation_offload=${offload} \ - actor_rollout_ref.model.use_liger=True \ - reward_model.reward_manager=async_multi_process \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=$worker_num \ - trainer.save_freq=1 \ - trainer.test_freq=5 \ - trainer.total_epochs=5 \ - trainer.log_val_generations=0 \ - trainer.resume_mode=auto \ - trainer.max_actor_ckpt_to_keep=3 - # data.id_val_files="$id_val_files" \ - # trainer.log_val_generations=50 \ \ No newline at end of file From b69d7e7dcad847afc4315ad15bf8af40634766bc Mon Sep 17 00:00:00 2001 From: Varad Pimpalkhute Date: Mon, 12 Jan 2026 20:21:07 +0000 Subject: [PATCH 18/20] add parsing timeout --- verl/utils/reward_score/synlogic/operation_verifier.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/verl/utils/reward_score/synlogic/operation_verifier.py b/verl/utils/reward_score/synlogic/operation_verifier.py index 26a795fe..3a2fa810 100644 --- a/verl/utils/reward_score/synlogic/operation_verifier.py +++ b/verl/utils/reward_score/synlogic/operation_verifier.py @@ -11,10 +11,10 @@ class OperationVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - @timeout_limit(seconds=20) + @timeout_limit(seconds=25) def _verify_with_timeout(): - ground_truth = math_verify.parse(data.answer, parsing_timeout=None) - parsed_answer = math_verify.parse(test_answer, parsing_timeout=None) + ground_truth = math_verify.parse(data.answer, parsing_timeout=10) + parsed_answer = math_verify.parse(test_answer, parsing_timeout=10) if parsed_answer is None: return False From 47b4cccee65daf7bd338c89f59ba736743c36516 Mon Sep 17 00:00:00 2001 From: Varad Pimpalkhute Date: Thu, 29 Jan 2026 22:08:20 +0000 Subject: [PATCH 19/20] fix critical bugs in reward functions --- scripts/train/k2p/k2p_hero_grpo_stage2.sh | 370 +++++++++++++++++ .../train/k2p/k2p_hero_grpo_stage2_fixed.sh | 369 +++++++++++++++++ .../k2p/k2p_hero_grpo_stage2_fixed_n64.sh | 369 +++++++++++++++++ .../k2p/k2p_hero_grpo_stage2_profiling.sh | 377 ++++++++++++++++++ scripts/train/k2p_hero_grpo_newData.sh | 33 +- verl/utils/reward_score/__init__.py | 24 +- verl/utils/reward_score/arcagi.py | 10 +- verl/utils/reward_score/codeio.py | 11 +- verl/utils/reward_score/graph_dataset.py | 6 - verl/utils/reward_score/ifbench/__init__.py | 11 +- verl/utils/reward_score/ifeval/__init__.py | 10 +- .../reward_score/math_llm_judge/__init__.py | 24 +- verl/utils/reward_score/naive_dapo.py | 1 - verl/utils/reward_score/prime_math/grader.py | 25 +- verl/utils/reward_score/puzzles_dataset.py | 6 - .../reward_score/reasoning_gym/__init__.py | 45 +-- .../reward_score/stem_llm_judge/__init__.py | 18 +- .../synlogic/arrow_maze_verifier.py | 66 +-- .../synlogic/boolean_expressions_verifier.py | 2 - 
.../synlogic/campsite_verifier.py | 2 - .../synlogic/dyck_language_errors_verifier.py | 2 - ...dyck_language_reasoning_errors_verifier.py | 2 - .../synlogic/dyck_language_verifier.py | 2 - .../synlogic/game_of_buggy_tables_verifier.py | 2 - .../synlogic/goods_exchange_verifier.py | 2 - .../synlogic/math_path_verifier.py | 2 - .../synlogic/minesweeper_verifier.py | 2 - .../synlogic/norinori_verifier.py | 2 - .../synlogic/number_wall_verifier.py | 2 - .../reward_score/synlogic/numbrix_verifier.py | 2 - .../synlogic/object_counting_verifier.py | 2 - .../synlogic/object_properties_verifier.py | 2 - .../synlogic/operation_verifier.py | 2 - .../synlogic/skyscraper_puzzle_verifier.py | 2 - .../synlogic/space_reasoning_tree_verifier.py | 2 - .../synlogic/space_reasoning_verifier.py | 2 - .../star_placement_puzzle_verifier.py | 2 - .../synlogic/time_sequence_verifier.py | 2 - .../synlogic/web_of_lies_verifier.py | 2 - .../synlogic/word_sorting_mistake_verifier.py | 2 - .../synlogic/word_sorting_verifier.py | 2 - .../synlogic/wordscapes_verifier.py | 2 - verl/utils/reward_score/zebra_puzzle.py | 6 - verl/workers/reward_manager/async_mp.py | 39 ++ 44 files changed, 1626 insertions(+), 242 deletions(-) create mode 100644 scripts/train/k2p/k2p_hero_grpo_stage2.sh create mode 100644 scripts/train/k2p/k2p_hero_grpo_stage2_fixed.sh create mode 100644 scripts/train/k2p/k2p_hero_grpo_stage2_fixed_n64.sh create mode 100644 scripts/train/k2p/k2p_hero_grpo_stage2_profiling.sh diff --git a/scripts/train/k2p/k2p_hero_grpo_stage2.sh b/scripts/train/k2p/k2p_hero_grpo_stage2.sh new file mode 100644 index 00000000..2d37a7e2 --- /dev/null +++ b/scripts/train/k2p/k2p_hero_grpo_stage2.sh @@ -0,0 +1,370 @@ +#!/bin/bash +#SBATCH --job-name=grpo-stage2-k2pRL-dataMix2 +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH 
--time=720:00:00 +#SBATCH --partition=higherprio +#SBATCH --exclude=azure-uk-hpc-H200-instance-337 + +# commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] +# job name: grpo-stage2-k2pRL-easy50k-7domains +# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-dataMix2-417843" # Fill in the checkpoint directory name to resume from, otherwise from scratch +# export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +# export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-036:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-058:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty + +# =================== Cluster Environment =================== +# export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl//bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export RAY_memory_usage_threshold=0.95 # Increase Ray memory threshold before killing workers +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ 
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 \ +NCCL_NVLS_ENABLE=0 + + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== + +# Training Data Configuration +DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_2" +train_file_list=() +id_val_file_list=() + +iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet" + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + "codegen__deduped_leetcode2k_2.4k.parquet" + "codegen__deduped_livecodebench_599.parquet" + "codegen__deduped_primeintellect_9.6k.parquet" + "codegen__deduped_taco_11.1k.parquet" + "ifbench__fixed_85.6k.parquet" + "simulation__codeio_fixed_12.1k.parquet" + "logic__arcagi1_297.parquet" + "logic__arcagi2_653.parquet" + "logic__barc_3.4k.parquet" + "logic__graph_logical_dataset_1.4k.parquet" + "logic__ordering_puzzle_dataset_2.9k.parquet" + "logic__reasoning_gym_40.6k.parquet" + "logic__synlogic_12.1k.parquet" + "logic__zebra_puzzle_dataset_5.0k.parquet" + "math__combined_118.2k.part1.parquet" + "math__combined_118.2k.part2.parquet" + "omni_math_4.43k.parquet" + "stem__nemotron_13.3k.parquet" + "stem__web_31.7k.parquet" + "table__hitab_7.4k.parquet" + "table__multihier_2.9k.parquet" +) + +echo "Collecting training files from ${DATA_MIX_DIR}..." 
+ +# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions" +for dataset in "${dataset_names[@]}"; do + for subdir in "main_questions" "131k_context_questions" "impossible_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +# for dataset in "${dataset_names[@]}"; do +# for subdir in "131k_context_questions"; do +# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" +# if [ -f "$file_path" ]; then +# echo "Adding: $file_path" +# id_val_file_list+=("'$file_path'") +# fi +# done +# done +# id_val_file_list+=("'$iq400_path'") + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +# id_val_files="[${id_val_file_list[*]}]" +unset IFS + +echo "Total training files found: ${#train_file_list[@]}" +# echo "Total ID validation files found: ${#id_val_file_list[@]}" + +# Test Data Configuration +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# 
ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (test) +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', '${reasoninggym_test_path}' + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) 
+BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} +# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO 
features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 64)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.2 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Ulysses sequence parallel size (16-way; set to 1 to disable) +gen_tp=4 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + 
data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=${offload} \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.num_processes=64 \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 + # trainer.log_val_generations=50 diff --git a/scripts/train/k2p/k2p_hero_grpo_stage2_fixed.sh b/scripts/train/k2p/k2p_hero_grpo_stage2_fixed.sh new file mode 100644 index 00000000..4d0070e0 --- /dev/null +++ b/scripts/train/k2p/k2p_hero_grpo_stage2_fixed.sh @@ -0,0 +1,369 @@ +#!/bin/bash +#SBATCH --job-name=grpo-stage2-k2pRL-dataMix2-fixed +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH 
--ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=higherprio +#SBATCH --exclude=azure-uk-hpc-H200-instance-337 + +# commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] +# job name: grpo-stage2-k2pRL-easy50k-7domains +# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-dataMix2-417843" # Fill in the checkpoint directory name to resume from, otherwise from scratch +# export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +# export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-117:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-233:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty + +# =================== Cluster Environment =================== +# export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl//bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ 
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 \ +NCCL_NVLS_ENABLE=0 + + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== + +# Training Data Configuration +DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_2" +train_file_list=() +id_val_file_list=() + +iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet" + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + "codegen__deduped_leetcode2k_2.4k.parquet" + "codegen__deduped_livecodebench_599.parquet" + "codegen__deduped_primeintellect_9.6k.parquet" + "codegen__deduped_taco_11.1k.parquet" + "ifbench__fixed_85.6k.parquet" + "simulation__codeio_fixed_12.1k.parquet" + "logic__arcagi1_297.parquet" + "logic__arcagi2_653.parquet" + "logic__barc_3.4k.parquet" + "logic__graph_logical_dataset_1.4k.parquet" + "logic__ordering_puzzle_dataset_2.9k.parquet" + "logic__reasoning_gym_40.6k.parquet" + "logic__synlogic_12.1k.parquet" + "logic__zebra_puzzle_dataset_5.0k.parquet" + "math__combined_118.2k.part1.parquet" + "math__combined_118.2k.part2.parquet" + "omni_math_4.43k.parquet" + "stem__nemotron_13.3k.parquet" + "stem__web_31.7k.parquet" + "table__hitab_7.4k.parquet" + "table__multihier_2.9k.parquet" +) + +echo "Collecting training files from ${DATA_MIX_DIR}..." 
+ +# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions" +for dataset in "${dataset_names[@]}"; do + for subdir in "main_questions" "131k_context_questions" "impossible_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +# for dataset in "${dataset_names[@]}"; do +# for subdir in "131k_context_questions"; do +# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" +# if [ -f "$file_path" ]; then +# echo "Adding: $file_path" +# id_val_file_list+=("'$file_path'") +# fi +# done +# done +# id_val_file_list+=("'$iq400_path'") + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +# id_val_files="[${id_val_file_list[*]}]" +unset IFS + +echo "Total training files found: ${#train_file_list[@]}" +# echo "Total ID validation files found: ${#id_val_file_list[@]}" + +# Test Data Configuration +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# 
ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (test) +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', '${reasoninggym_test_path}' + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) 
+BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} +# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO 
features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 64)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.2 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Ulysses sequence parallel size (16-way; set to 1 to disable) +gen_tp=4 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + 
data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=48000 \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + 
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=${offload} \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.num_processes=64 \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 + # trainer.log_val_generations=50 diff --git a/scripts/train/k2p/k2p_hero_grpo_stage2_fixed_n64.sh b/scripts/train/k2p/k2p_hero_grpo_stage2_fixed_n64.sh new file mode 100644 index 00000000..d6524857 --- /dev/null +++ b/scripts/train/k2p/k2p_hero_grpo_stage2_fixed_n64.sh @@ -0,0 +1,369 @@ +#!/bin/bash +#SBATCH --job-name=grpo-stage2-k2pRL-dataMix2-n64-fixed +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=higherprio +#SBATCH --exclude=azure-uk-hpc-H200-instance-337 + +# commenting out... 
SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] +# job name: grpo-stage2-k2pRL-easy50k-7domains +# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +# export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +# export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-036:8000,http://azure-uk-hpc-H200-instance-061:8000,http://azure-uk-hpc-H200-instance-062:8000,http://azure-uk-hpc-H200-instance-175:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-058:8000,http://azure-uk-hpc-H200-instance-176:8000,http://azure-uk-hpc-H200-instance-387:8000,http://azure-uk-hpc-H200-instance-388:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty + +# =================== Cluster Environment =================== +# export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl//bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ 
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 \ +NCCL_NVLS_ENABLE=0 + + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== + +# Training Data Configuration +DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_2" +train_file_list=() +id_val_file_list=() + +iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet" + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + "codegen__deduped_leetcode2k_2.4k.parquet" + "codegen__deduped_livecodebench_599.parquet" + "codegen__deduped_primeintellect_9.6k.parquet" + "codegen__deduped_taco_11.1k.parquet" + "ifbench__fixed_85.6k.parquet" + "simulation__codeio_fixed_12.1k.parquet" + "logic__arcagi1_297.parquet" + "logic__arcagi2_653.parquet" + "logic__barc_3.4k.parquet" + "logic__graph_logical_dataset_1.4k.parquet" + "logic__ordering_puzzle_dataset_2.9k.parquet" + "logic__reasoning_gym_40.6k.parquet" + "logic__synlogic_12.1k.parquet" + "logic__zebra_puzzle_dataset_5.0k.parquet" + "math__combined_118.2k.part1.parquet" + "math__combined_118.2k.part2.parquet" + "omni_math_4.43k.parquet" + "stem__nemotron_13.3k.parquet" + "stem__web_31.7k.parquet" + "table__hitab_7.4k.parquet" + "table__multihier_2.9k.parquet" +) + +echo "Collecting training files from ${DATA_MIX_DIR}..." 
+ +# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions" +for dataset in "${dataset_names[@]}"; do + for subdir in "main_questions" "131k_context_questions" "impossible_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +# for dataset in "${dataset_names[@]}"; do +# for subdir in "131k_context_questions"; do +# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" +# if [ -f "$file_path" ]; then +# echo "Adding: $file_path" +# id_val_file_list+=("'$file_path'") +# fi +# done +# done +# id_val_file_list+=("'$iq400_path'") + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +# id_val_files="[${id_val_file_list[*]}]" +unset IFS + +echo "Total training files found: ${#train_file_list[@]}" +# echo "Total ID validation files found: ${#id_val_file_list[@]}" + +# Test Data Configuration +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# 
ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (test) +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', '${reasoninggym_test_path}' + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) 
+BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} +# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO 
features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 64)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=64 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.2 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Disable sequence parallelism +gen_tp=4 +gen_max_num_seqs=1024 +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + 
data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=60000 \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + 
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=${offload} \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.num_processes=128 \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 + # trainer.log_val_generations=50 diff --git a/scripts/train/k2p/k2p_hero_grpo_stage2_profiling.sh b/scripts/train/k2p/k2p_hero_grpo_stage2_profiling.sh new file mode 100644 index 00000000..5586bf1a --- /dev/null +++ b/scripts/train/k2p/k2p_hero_grpo_stage2_profiling.sh @@ -0,0 +1,377 @@ +#!/bin/bash +#SBATCH --job-name=grpo-stage2-k2pRL-dataMix2-profiling +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=higherprio +#SBATCH --exclude=azure-uk-hpc-H200-instance-337 + +# commenting out... 
SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] +# job name: grpo-stage2-k2pRL-easy50k-7domains +# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +# export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +# export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-036:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-058:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/varad.pimpalkhute/anaconda3/envs/sync-rl-v5/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export RAY_memory_usage_threshold=0.95 # Increase Ray memory threshold before killing workers +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 \ +NCCL_NVLS_ENABLE=0 + + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames 
"$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== + +# Training Data Configuration +DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_2" +train_file_list=() +id_val_file_list=() + +iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet" + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + "codegen__deduped_leetcode2k_2.4k.parquet" + "codegen__deduped_livecodebench_599.parquet" + "codegen__deduped_primeintellect_9.6k.parquet" + "codegen__deduped_taco_11.1k.parquet" + "ifbench__fixed_85.6k.parquet" + "simulation__codeio_fixed_12.1k.parquet" + "logic__arcagi1_297.parquet" + "logic__arcagi2_653.parquet" + "logic__barc_3.4k.parquet" + "logic__graph_logical_dataset_1.4k.parquet" + "logic__ordering_puzzle_dataset_2.9k.parquet" + "logic__reasoning_gym_40.6k.parquet" + "logic__synlogic_12.1k.parquet" + "logic__zebra_puzzle_dataset_5.0k.parquet" + "math__combined_118.2k.part1.parquet" + "math__combined_118.2k.part2.parquet" + "omni_math_4.43k.parquet" + "stem__nemotron_13.3k.parquet" + "stem__web_31.7k.parquet" + "table__hitab_7.4k.parquet" + "table__multihier_2.9k.parquet" +) + +echo "Collecting training files from ${DATA_MIX_DIR}..." 
+ +# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions" +for dataset in "${dataset_names[@]}"; do + for subdir in "main_questions" "131k_context_questions" "impossible_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +# for dataset in "${dataset_names[@]}"; do +# for subdir in "131k_context_questions"; do +# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" +# if [ -f "$file_path" ]; then +# echo "Adding: $file_path" +# id_val_file_list+=("'$file_path'") +# fi +# done +# done +# id_val_file_list+=("'$iq400_path'") + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +# id_val_files="[${id_val_file_list[*]}]" +unset IFS + +echo "Total training files found: ${#train_file_list[@]}" +# echo "Total ID validation files found: ${#id_val_file_list[@]}" + +# Test Data Configuration +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# 
ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (test) +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) 
+BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} +# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491" + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO 
features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 64)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.2 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Disable sequence parallelism +gen_tp=4 +gen_max_num_seqs=256 # REDUCED from 1024 to fix OOM with long sequences +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + 
data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.enable_prefix_caching=False \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=${offload} \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.num_processes=48 \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=5 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 \ + global_profiler.tool=torch \ + global_profiler.steps=[1,2,3,4,5,6,7,8,9] \ + actor_rollout_ref.actor.profiler.enable=True \ + actor_rollout_ref.actor.profiler.all_ranks=True \ + actor_rollout_ref.rollout.profiler.enable=True \ + actor_rollout_ref.rollout.profiler.all_ranks=True + + + # trainer.log_val_generations=50 diff --git a/scripts/train/k2p_hero_grpo_newData.sh 
b/scripts/train/k2p_hero_grpo_newData.sh index a387f442..436215c8 100644 --- a/scripts/train/k2p_hero_grpo_newData.sh +++ b/scripts/train/k2p_hero_grpo_newData.sh @@ -11,7 +11,10 @@ #SBATCH --exclusive #SBATCH --time=720:00:00 #SBATCH --partition=higherprio + # commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] +# job name: grpo-stage2-k2pRL-easy50k-7domains +# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct # =================== Frequently Used Variables =================== RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-easy50k-7domains-415354" # Fill in the checkpoint directory name to resume from, otherwise from scratch @@ -78,23 +81,23 @@ dataset_names=( "codegen__deduped_primeintellect_9.6k.parquet" "codegen__deduped_taco_11.1k.parquet" "ifbench__fixed_85.6k.parquet" + "simulation__codeio_fixed_12.1k.parquet" + "logic__arcagi1_297.parquet" + "logic__arcagi2_653.parquet" + "logic__barc_3.4k.parquet" + "logic__graph_logical_dataset_1.4k.parquet" + "logic__ordering_puzzle_dataset_2.9k.parquet" + "logic__reasoning_gym_40.6k.parquet" + "logic__synlogic_12.1k.parquet" + "logic__zebra_puzzle_dataset_5.0k.parquet" "math__combined_118.2k.part1.parquet" "math__combined_118.2k.part2.parquet" - "omni_math_4.43k_dedup.parquet" + "omni_math_4.43k.parquet" "stem__nemotron_13.3k.parquet" "stem__web_31.7k.parquet" "table__hitab_7.4k.parquet" "table__multihier_2.9k.parquet" ) -# "simulation__codeio_fixed_12.1k.parquet" -# "logic__arcagi1_297.parquet" -# "logic__arcagi2_653.parquet" -# "logic__barc_3.4k.parquet" -# "logic__graph_logical_dataset_1.4k.parquet" -# "logic__ordering_puzzle_dataset_2.9k.parquet" -# "logic__reasoning_gym_40.6k.parquet" -# "logic__synlogic_12.1k.parquet" -# "logic__zebra_puzzle_dataset_5.0k.parquet" echo "Collecting training files from ${DATA_MIX_DIR}..." 
@@ -176,7 +179,7 @@ test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${a # BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface # BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k -BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface # =================== Logging =================== WANDB_PROJECT=k2plus_rl @@ -290,9 +293,9 @@ offload=True actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ actor_rollout_ref.actor.clip_ratio_c=10.0 \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=48000 \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ actor_rollout_ref.actor.strategy="fsdp2" \ - actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr=5e-7 \ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ actor_rollout_ref.actor.optim.weight_decay=0.1 \ actor_rollout_ref.actor.optim.warmup_style=constant \ @@ -345,7 +348,7 @@ offload=True actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.model.enable_activation_offload=${offload} \ actor_rollout_ref.model.use_liger=True \ - reward_model.reward_manager=dapo \ + reward_model.reward_manager=async_multi_process \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ 
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ @@ -361,4 +364,4 @@ offload=True trainer.total_epochs=5 \ trainer.resume_mode=auto \ trainer.max_actor_ckpt_to_keep=3 - # trainer.log_val_generations=50 \ \ No newline at end of file + # trainer.log_val_generations=50 \ No newline at end of file diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index bb014b13..a1b07cd5 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -13,6 +13,8 @@ # limitations under the License. # from . import gsm8k, math, prime_math, prime_code +import time + from verl.utils.import_utils import deprecated @@ -40,6 +42,8 @@ def default_compute_score( Raises: NotImplementedError: If the reward function is not implemented for the given data source. """ + start_time = time.perf_counter() + # Handle extra_info format robustly reward_metric = None if extra_info and isinstance(extra_info, dict): @@ -209,12 +213,28 @@ def default_compute_score( else: raise NotImplementedError(f"Reward function is not implemented for {data_source=}") + # Calculate elapsed time + elapsed = time.perf_counter() - start_time + + # Log slow samples (>10 seconds) + if elapsed > 10.0: + print(f"[SLOW REWARD] {data_source}: {elapsed:.2f}s") + + # Return result with timing metadata + # Always ensure "score" and "acc" keys are present for consistency if isinstance(res, dict): + res["_reward_time"] = elapsed + res["_data_source"] = data_source + # Ensure "acc" is present if not already + if "acc" not in res: + res["acc"] = res.get("score", 0.0) return res elif isinstance(res, int | float | bool): - return float(res) + score = float(res) + return {"score": score, "acc": score, "_reward_time": elapsed, "_data_source": data_source} else: - return float(res[0]) + score = float(res[0]) + return {"score": score, "acc": score, "_reward_time": elapsed, "_data_source": data_source} 
@deprecated("verl.utils.reward_score.default_compute_score") diff --git a/verl/utils/reward_score/arcagi.py b/verl/utils/reward_score/arcagi.py index 198db591..b30e0d21 100644 --- a/verl/utils/reward_score/arcagi.py +++ b/verl/utils/reward_score/arcagi.py @@ -1,7 +1,6 @@ import re import ast import numpy as np -from verl.utils.py_functional import timeout_limit def extract_solution(solution_str): @@ -94,20 +93,13 @@ def compare_solutions_with_padding(generated_output, correct_output, pad_value=- def compute_score( model_output: str, ground_truth: np.ndarray, extra_info: any = None ) -> float: - @timeout_limit(seconds=10) - def _compute_score_with_timeout(): + try: model_output_str = str(model_output) final_answer = extract_solution(model_output_str) is_correct, correct_percentage = compare_solutions_with_padding( final_answer, ground_truth ) return {"score": is_correct, "acc": is_correct} - - try: - return _compute_score_with_timeout() - except TimeoutError: - print("Computation timed out in arcagi") - return {"score": 0.0, "acc": 0.0} except Exception as e: print(f"Error in compute_score in arcagi: {e}") return {"score": 0.0, "acc": 0.0} diff --git a/verl/utils/reward_score/codeio.py b/verl/utils/reward_score/codeio.py index 67b8b19b..5a84c719 100644 --- a/verl/utils/reward_score/codeio.py +++ b/verl/utils/reward_score/codeio.py @@ -2,7 +2,6 @@ import ast import re from typing import Dict, Any, Tuple, List -from verl.utils.py_functional import timeout_limit def normalize(obj: Any) -> Any: """ @@ -140,16 +139,10 @@ def compute_score(model_output: str, ground_truth: str, extra_info: any = None) """ Compute score dict for evaluation harness. 
""" - @timeout_limit(seconds=10) - def _compute_score_with_timeout(): - correct, _ = check_accuracy(str(model_output), str(ground_truth), any_order=False) - return {"score": correct, "acc": correct} try: - return _compute_score_with_timeout() - except TimeoutError: - print("Computation timed out in codeio") - return {"score": False, "acc": False} + correct, _ = check_accuracy(str(model_output), str(ground_truth), any_order=False) + return {"score": correct, "acc": correct} except Exception as e: print(f"Error in compute_score in codeio: {e}") return {"score": False, "acc": False} diff --git a/verl/utils/reward_score/graph_dataset.py b/verl/utils/reward_score/graph_dataset.py index 831b4821..51983e33 100644 --- a/verl/utils/reward_score/graph_dataset.py +++ b/verl/utils/reward_score/graph_dataset.py @@ -3,8 +3,6 @@ import ast import operator -from verl.utils.py_functional import timeout_limit - def extract_solution(solution_str): @@ -31,7 +29,6 @@ def compute_score( timeout: maximum time in seconds to allow for computation """ - @timeout_limit(seconds=timeout) def _compute_with_timeout(): if not isinstance(ground_truth, str): ground_truth_str = str(ground_truth) @@ -56,9 +53,6 @@ def _compute_with_timeout(): try: score = _compute_with_timeout() - except TimeoutError: - print("Computation timed out in graph_dataset") - score = 0.0 except Exception as e: print(f"Error in compute_score in graph_dataset: {e}") score = 0.0 diff --git a/verl/utils/reward_score/ifbench/__init__.py b/verl/utils/reward_score/ifbench/__init__.py index 1369593e..12f07b96 100644 --- a/verl/utils/reward_score/ifbench/__init__.py +++ b/verl/utils/reward_score/ifbench/__init__.py @@ -2,7 +2,6 @@ import json import numpy as np -from verl.utils.py_functional import timeout_limit from .instructions_registry import INSTRUCTION_DICT @@ -19,8 +18,8 @@ def compute_score(solution_str, ground_truth, extra_info=None): Returns: dict: {"score": float, "acc": bool} """ - @timeout_limit(seconds=30) - def 
_compute_score_with_timeout(): + + try: # Strip off any thinking section if "" in solution_str: answer = solution_str.split("", 1)[1].strip() @@ -68,12 +67,6 @@ def _compute_score_with_timeout(): # Return 1.0 if all constraints are satisfied, 0.0 otherwise score = 1.0 if all(results) else 0.0 return {"score": score, "acc": score == 1.0} - - try: - return _compute_score_with_timeout() - except TimeoutError: - print("Computation timed out in ifbench") - return {"score": 0.0, "acc": False} except Exception as e: print(f"Error in compute_score in ifbench: {e}") return {"score": 0.0, "acc": False} diff --git a/verl/utils/reward_score/ifeval/__init__.py b/verl/utils/reward_score/ifeval/__init__.py index 3d651190..aaf4dcea 100644 --- a/verl/utils/reward_score/ifeval/__init__.py +++ b/verl/utils/reward_score/ifeval/__init__.py @@ -1,6 +1,5 @@ from verl.utils.reward_score.ifeval import instructions_registry import numpy as np -from verl.utils.py_functional import timeout_limit def compute_score(solution_str, ground_truth, extra_info): """The scoring function for IFEval. 
@@ -14,8 +13,7 @@ def compute_score(solution_str, ground_truth, extra_info): format_score: the score for the format score: the score for the correct answer """ - @timeout_limit(seconds=60) - def _compute_score_with_timeout(): + try: if "" in solution_str: answer = solution_str.split("")[1] else: @@ -49,12 +47,6 @@ def _compute_score_with_timeout(): "score": all(is_following_list), "acc": all(is_following_list), } - - try: - return _compute_score_with_timeout() - except TimeoutError: - print("Computation timed out in ifbench") - return {"score": 0.0, "acc": False} except Exception as e: print(f"Error in compute_score in ifbench: {e}") return {"score": 0.0, "acc": False} \ No newline at end of file diff --git a/verl/utils/reward_score/math_llm_judge/__init__.py b/verl/utils/reward_score/math_llm_judge/__init__.py index f725986f..f252aae8 100644 --- a/verl/utils/reward_score/math_llm_judge/__init__.py +++ b/verl/utils/reward_score/math_llm_judge/__init__.py @@ -48,9 +48,6 @@ from . import math_normalize from .grader import math_equal -import requests -from verl.utils.py_functional import timeout_limit - # import math_normalize # from grader import math_equal @@ -376,16 +373,27 @@ def match_answer(response): def llm_check_answer(model_output: str, ground_truth: str, question: str) -> bool: # use llm to check if the answer is correct + # Supports multiple endpoints for load balancing - separate URLs with commas + # e.g., MATH_LLM_JUDGE_URL="http://host1:8000,http://host2:8000,http://host3:8000" - # url = "http://176.56.200.81:30000/v1/chat/completions" import os - url_base = os.getenv("MATH_LLM_JUDGE_URL") - if not url_base: + import random + + url_base_str = os.getenv("MATH_LLM_JUDGE_URL") + if not url_base_str: raise ValueError("MATH_LLM_JUDGE_URL is not set") + + # Support multiple endpoints separated by commas + endpoints = [url.strip() for url in url_base_str.split(",") if url.strip()] + if not endpoints: + raise ValueError("MATH_LLM_JUDGE_URL contains no valid 
endpoints") + + # Randomly select an endpoint for load balancing + url_base = random.choice(endpoints) url = url_base.rstrip("/") + "/v1/chat/completions" - + prompt = input_template.format(QUESTION=question, STUDENT_ANSWER=model_output, REFERENCE_ANSWER=ground_truth) - + data = { "model": "openai/gpt-oss-120b", "messages": [{"role": "user", "content": prompt}], diff --git a/verl/utils/reward_score/naive_dapo.py b/verl/utils/reward_score/naive_dapo.py index 739b9c9e..af4ba8be 100644 --- a/verl/utils/reward_score/naive_dapo.py +++ b/verl/utils/reward_score/naive_dapo.py @@ -24,7 +24,6 @@ from .prime_math import math_normalize from .prime_math.grader import math_equal -from verl.utils.py_functional import timeout_limit # Constants for normalization diff --git a/verl/utils/reward_score/prime_math/grader.py b/verl/utils/reward_score/prime_math/grader.py index 72bb749f..c2e0ed7e 100644 --- a/verl/utils/reward_score/prime_math/grader.py +++ b/verl/utils/reward_score/prime_math/grader.py @@ -102,9 +102,6 @@ from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr -# verl related -from verl.utils.py_functional import timeout_limit - def is_digit(s): try: @@ -325,11 +322,7 @@ def symbolic_equal(a, b, tolerance, timeout=10.0): def _parse(s): for f in [parse_expr, parse_latex]: try: - with timeout_limit(seconds=timeout): - return f(s) - except TimeoutError: - print(f"Parsing timed out for {s}") - continue + return f(s) except Exception: continue return s @@ -338,22 +331,14 @@ def _parse(s): b = _parse(b) try: - with timeout_limit(seconds=timeout): - if simplify(a - b) == 0: - return True - except TimeoutError: - print(f"Simplification timed out for {a} - {b}") - pass + if simplify(a - b) == 0: + return True except Exception: pass try: - with timeout_limit(seconds=timeout): - if isclose(N(a), N(b), rel_tol=tolerance): - return True - except TimeoutError: - print(f"Numerical evaluation timed out for {a}, {b}") - pass + if isclose(N(a), N(b), 
rel_tol=tolerance): + return True except Exception: pass return False diff --git a/verl/utils/reward_score/puzzles_dataset.py b/verl/utils/reward_score/puzzles_dataset.py index 23401e77..ca68c9dd 100644 --- a/verl/utils/reward_score/puzzles_dataset.py +++ b/verl/utils/reward_score/puzzles_dataset.py @@ -3,8 +3,6 @@ import ast import operator -from verl.utils.py_functional import timeout_limit - def extract_solution(solution_str): # Find the answer tag content @@ -77,7 +75,6 @@ def compute_score(solution_str, ground_truth, extra_info: any = None, method='st method: the method to extract the solution timeout: maximum time in seconds to allow for computation """ - @timeout_limit(seconds=timeout) def _compute_with_timeout(): target = ground_truth.tolist() if not isinstance(ground_truth,list) else ground_truth predicted_arrangement = extract_solution(solution_str=solution_str) @@ -103,9 +100,6 @@ def _compute_with_timeout(): score = 0.0 try: score = _compute_with_timeout() - except TimeoutError: - print("Computation timed out in puzzles_dataset") - score = 0.0 except Exception as e: print(f"Error in compute_score in puzzles_dataset: {e}") score = 0.0 diff --git a/verl/utils/reward_score/reasoning_gym/__init__.py b/verl/utils/reward_score/reasoning_gym/__init__.py index 4ba9139f..3d6e2a7a 100644 --- a/verl/utils/reward_score/reasoning_gym/__init__.py +++ b/verl/utils/reward_score/reasoning_gym/__init__.py @@ -1,7 +1,14 @@ import reasoning_gym import json import re -from verl.utils.py_functional import timeout_limit +import gc +from functools import lru_cache + + +@lru_cache(maxsize=64) +def _get_cached_scorer(task): + """Cache scorer objects to avoid recreating them for every sample.""" + return reasoning_gym.get_score_answer_fn(task) def compute_score(solution_str, ground_truth, extra_info=None, item=None): """ @@ -16,7 +23,6 @@ def compute_score(solution_str, ground_truth, extra_info=None, item=None): Returns: dict: {"score": float, "acc": float} """ - 
@timeout_limit(seconds=10) def _compute_score_with_timeout(): task = None entry = None @@ -60,8 +66,8 @@ def _compute_score_with_timeout(): if not task: raise ValueError("task must be provided in extra_info, item, or ground_truth dict.") - # 4. Get scoring function - scorer = reasoning_gym.get_score_answer_fn(task) + # 4. Get scoring function (cached to avoid recreating for every sample) + scorer = _get_cached_scorer(task) # 5. Get entry if entry is None: @@ -84,30 +90,19 @@ def _compute_score_with_timeout(): # 6. Extract clean answer from solution_str clean_answer = extract_answer_from_solution(solution_str) - - # 7. Scoring with task-specific fixes - debug_log_path = "reasoning_gym_debug.log" + + # 7. Scoring with task-specific fixes (debug logging removed for memory efficiency) try: - with open(debug_log_path, "a", encoding="utf-8") as f: - f.write("[DEBUG] solution_str: {}\n".format(solution_str)) - f.write("[DEBUG] clean_answer: {}\n".format(clean_answer)) - f.write("[DEBUG] ground_truth: {}\n".format(ground_truth)) - f.write("[DEBUG] task: {}\n".format(task)) - f.write("[DEBUG] metadata: {}\n".format(json.dumps(entry.get("metadata", {}), ensure_ascii=False, indent=2))) - - # Get raw score from reasoning_gym using clean answer - raw_score = scorer(answer=clean_answer, entry=entry) - - # Apply task-specific corrections for known issues - corrected_score = apply_task_specific_corrections(task, solution_str, ground_truth, raw_score) - - f.write("[DEBUG] raw_score: {}\n".format(raw_score)) - f.write("[DEBUG] corrected_score: {}\n".format(corrected_score)) - + # Get raw score from reasoning_gym using clean answer + raw_score = scorer(answer=clean_answer, entry=entry) + + # Apply task-specific corrections for known issues + corrected_score = apply_task_specific_corrections(task, solution_str, ground_truth, raw_score) + return {"score": float(corrected_score), "acc": float(corrected_score)} except Exception as e: - with open(debug_log_path, "a", encoding="utf-8") as 
f: - f.write(f"Error in reasoning gym scoring: {e}\n") + # Only print errors, don't write to file + print(f"Error in reasoning gym scoring: {e}") return {"score": 0.0, "acc": 0.0} try: diff --git a/verl/utils/reward_score/stem_llm_judge/__init__.py b/verl/utils/reward_score/stem_llm_judge/__init__.py index a0a41961..3aeb5680 100644 --- a/verl/utils/reward_score/stem_llm_judge/__init__.py +++ b/verl/utils/reward_score/stem_llm_judge/__init__.py @@ -3,13 +3,15 @@ Prerequisite: - Set env var STEM_LLM_JUDGE_URL to an OpenAI-compatible /v1/chat/completions. + - Supports multiple endpoints for load balancing - separate URLs with commas: + export STEM_LLM_JUDGE_URL="http://host1:8000,http://host2:8000,http://host3:8000" - Launch the service with your preferred model beforehand, e.g. vllm serve TIGER-Lab/general-verifier - export STEM_LLM_JUDGE_URL=http://127.0.0.1:8000/v1/chat/completions + export STEM_LLM_JUDGE_URL=http://127.0.0.1:8000 """ -import os, re, requests +import os, re, random, requests from typing import Tuple # # ------------ Prompt template ------------------------------------------------ @@ -35,9 +37,17 @@ # ------------ Core LLM call -------------------------------------------------- def _llm_judge(question: str, student: str, reference: str, verbose: bool = False) -> bool: - url_base = os.getenv("STEM_LLM_JUDGE_URL") - if not url_base: + url_base_str = os.getenv("STEM_LLM_JUDGE_URL") + if not url_base_str: raise EnvironmentError("STEM_LLM_JUDGE_URL not set") + + # Support multiple endpoints separated by commas for load balancing + endpoints = [url.strip() for url in url_base_str.split(",") if url.strip()] + if not endpoints: + raise EnvironmentError("STEM_LLM_JUDGE_URL contains no valid endpoints") + + # Randomly select an endpoint for load balancing + url_base = random.choice(endpoints) url = url_base.rstrip("/") + "/v1/chat/completions" # prompt = JUDGE_TEMPLATE.format( diff --git a/verl/utils/reward_score/synlogic/arrow_maze_verifier.py 
b/verl/utils/reward_score/synlogic/arrow_maze_verifier.py index d06cfc63..74f084cd 100644 --- a/verl/utils/reward_score/synlogic/arrow_maze_verifier.py +++ b/verl/utils/reward_score/synlogic/arrow_maze_verifier.py @@ -3,7 +3,6 @@ from .verifier import Verifier from .data import Data import re -from verl.utils.py_functional import timeout_limit class ArrowMazeVerifier(Verifier): """ @@ -44,8 +43,7 @@ def verify(self, data: Data, test_solution_str: str) -> bool: @param test_solution_str: 测试答案字符串 (JSON格式的二维数组) @return: 答案是否正确 """ - @timeout_limit(seconds=60) - def _verify_with_timeout(): + try: test_answer_str = self.extract_answer(test_solution_str) if not test_answer_str: # print("答案为空,验证失败") @@ -60,73 +58,24 @@ def _verify_with_timeout(): # 检查答案是否符合要求 if not self._verify_grid_size(test_answer, question_grid): - # print("答案网格大小与题目不匹配") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案网格大小与题目不匹配" + '\n') - f.write('-'*32 + '\n') return False if not self._verify_number_positions(test_answer, question_grid): - # print("答案中数字位置或值与题目不匹配") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中数字位置或值与题目不匹配" + '\n') - f.write('-'*32 + '\n') return False if not self._verify_all_blanks_filled(test_answer, question_grid): - # print("答案中有空格未被填满") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中有空格未被填满" + '\n') - f.write('-'*32 + '\n') return False if not 
self._verify_arrow_symbols(test_answer): - # print("答案中包含非法箭头符号") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中包含非法箭头符号" + '\n') - f.write('-'*32 + '\n') return False if not self._verify_prefilled_arrows(test_answer, question_grid): - # print("答案中预填箭头与题目不一致") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中预填箭头与题目不一致" + '\n') - f.write('-'*32 + '\n') return False if not self._verify_arrow_rays(test_answer): - # print("答案中存在未被射线覆盖的箭头") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中存在未被射线覆盖的箭头" + '\n') - f.write('-'*32 + '\n') return False if not self._verify_number_rays(test_answer): - # print("答案中数字的射线箭头串总数不符合要求") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("答案中数字的射线箭头串总数不符合要求" + '\n') - f.write('-'*32 + '\n') return False # 所有验证都通过 @@ -134,20 +83,7 @@ def _verify_with_timeout(): return True except Exception as e: - # print(f"验证过程中出错: {e}") - with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: - f.write("test_solution_str: " + test_solution_str + '\n') - f.write("test_answer_str: " + test_answer_str + '\n') - f.write("question_grid: " + str(data.metadata["maze"]) + '\n') - f.write("验证过程中出错" + str(e) + '\n') - 
f.write('-'*32 + '\n') return False - - try: - return _verify_with_timeout() - except TimeoutError: - print("Verification timed out (ArrowMazeVerifier)") - return False except Exception as e: print(f"Verification error (ArrowMazeVerifier): {e}") return False diff --git a/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py b/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py index d5e6b968..3d84e788 100644 --- a/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py +++ b/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py @@ -1,7 +1,6 @@ import re from .data import Data from .verifier import Verifier -from verl.utils.py_functional import timeout_limit class BooleanExpressionsVerifier(Verifier): """ @@ -9,7 +8,6 @@ class BooleanExpressionsVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): test_answer_extracted = self.extract_answer(test_answer) if test_answer_extracted is None: diff --git a/verl/utils/reward_score/synlogic/campsite_verifier.py b/verl/utils/reward_score/synlogic/campsite_verifier.py index 60203268..4387c8f5 100644 --- a/verl/utils/reward_score/synlogic/campsite_verifier.py +++ b/verl/utils/reward_score/synlogic/campsite_verifier.py @@ -3,7 +3,6 @@ import re import ast from typing import List, Set, Tuple, Dict -from verl.utils.py_functional import timeout_limit class CampsiteVerifier(Verifier): @@ -12,7 +11,6 @@ class CampsiteVerifier(Verifier): """ def verify(self, data: Data, test_solution: str): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): test_answer = self.extract_answer(test_solution) original_grid = data.metadata["grid"] diff --git a/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py index 1afb8ff7..3f426a71 100644 --- a/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py +++ 
b/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py @@ -1,7 +1,6 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re -from verl.utils.py_functional import timeout_limit class DyckLanguageErrorsVerifier(Verifier): @@ -17,7 +16,6 @@ def verify(self, data: Data, test_answer: str): @return: 回答是否正确的布尔值 """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): test_answer_extracted = self.extract_answer(test_solution=test_answer) # 获取正确答案 diff --git a/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py index 17c68087..dbf4d7c6 100644 --- a/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py +++ b/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py @@ -1,7 +1,6 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re -from verl.utils.py_functional import timeout_limit class DyckLanguageReasoningErrorsVerifier(Verifier): @@ -17,7 +16,6 @@ def verify(self, data: Data, test_answer: str): @return: 回答是否正确的布尔值 """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): test_answer_extracted = self.extract_answer(test_solution=test_answer) # 获取元数据中的正确答案 diff --git a/verl/utils/reward_score/synlogic/dyck_language_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_verifier.py index 20b8125c..1b30d2b1 100644 --- a/verl/utils/reward_score/synlogic/dyck_language_verifier.py +++ b/verl/utils/reward_score/synlogic/dyck_language_verifier.py @@ -1,7 +1,6 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re -from verl.utils.py_functional import timeout_limit class DyckLanguageVerifier(Verifier): @@ -17,7 +16,6 @@ def verify(self, data: Data, test_answer: str) -> bool: @return: 回答是否正确的布尔值 """ try: - 
@timeout_limit(seconds=10) def _verify_with_timeout(): # 获取元数据中的完整序列 full_sequence = data.metadata["full_sequence"] diff --git a/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py b/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py index 7993c506..aea853c4 100644 --- a/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py +++ b/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py @@ -1,7 +1,6 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re -from verl.utils.py_functional import timeout_limit class BuggyTableVerifier(Verifier): """ @@ -28,7 +27,6 @@ def verify(self, data: Data, test_answer: str) -> bool: @return: bool indicating whether the answer is correct """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): # Extract the expected answer from the Data object expected_answer = data.answer if data and hasattr(data, 'answer') else "" diff --git a/verl/utils/reward_score/synlogic/goods_exchange_verifier.py b/verl/utils/reward_score/synlogic/goods_exchange_verifier.py index fbe1eff2..31392d02 100644 --- a/verl/utils/reward_score/synlogic/goods_exchange_verifier.py +++ b/verl/utils/reward_score/synlogic/goods_exchange_verifier.py @@ -1,7 +1,6 @@ import re from .data import Data from .verifier import Verifier -from verl.utils.py_functional import timeout_limit class GoodsExchangeVerifier(Verifier): """ @@ -16,7 +15,6 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): test_answer = self.extract_answer(test_solution) # 获取元数据中的正确答案 diff --git a/verl/utils/reward_score/synlogic/math_path_verifier.py b/verl/utils/reward_score/synlogic/math_path_verifier.py index 5ebe1d81..d447a5c9 100644 --- a/verl/utils/reward_score/synlogic/math_path_verifier.py +++ b/verl/utils/reward_score/synlogic/math_path_verifier.py @@ -3,7 +3,6 @@ import numpy as np from .data 
import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END -from verl.utils.py_functional import timeout_limit class MathPathVerifier(Verifier): @@ -19,7 +18,6 @@ def verify(self, data: Data, test_answer: str): @return: 回答是否正确的布尔值 """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): try: test_answer_extracted = self.extract_answer(test_solution=test_answer) diff --git a/verl/utils/reward_score/synlogic/minesweeper_verifier.py b/verl/utils/reward_score/synlogic/minesweeper_verifier.py index f0f5ae79..8bb0240c 100644 --- a/verl/utils/reward_score/synlogic/minesweeper_verifier.py +++ b/verl/utils/reward_score/synlogic/minesweeper_verifier.py @@ -3,7 +3,6 @@ import re import json from typing import List, Tuple -from verl.utils.py_functional import timeout_limit class MinesweeperVerifier(Verifier): @@ -13,7 +12,6 @@ class MinesweeperVerifier(Verifier): """ def verify(self, data: Data, test_solution: str, **kwargs): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): # 从解答中提取地雷坐标 predicted_mines = self.extract_answer(test_solution) diff --git a/verl/utils/reward_score/synlogic/norinori_verifier.py b/verl/utils/reward_score/synlogic/norinori_verifier.py index c817198d..65793b66 100644 --- a/verl/utils/reward_score/synlogic/norinori_verifier.py +++ b/verl/utils/reward_score/synlogic/norinori_verifier.py @@ -2,7 +2,6 @@ from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import re from collections import defaultdict -from verl.utils.py_functional import timeout_limit class NorinoriVerifier(Verifier): """ @@ -25,7 +24,6 @@ def verify(self, data: Data, test_solution: str): bool -- 答案是否正确 """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): # 从游戏数据中获取区域网格 region_grid = data.metadata["region_grid"] diff --git a/verl/utils/reward_score/synlogic/number_wall_verifier.py b/verl/utils/reward_score/synlogic/number_wall_verifier.py index aefbedae..918f96b7 100644 --- 
a/verl/utils/reward_score/synlogic/number_wall_verifier.py +++ b/verl/utils/reward_score/synlogic/number_wall_verifier.py @@ -3,7 +3,6 @@ import re import json from collections import deque -from verl.utils.py_functional import timeout_limit class NumberWallVerifier(Verifier): """ @@ -12,7 +11,6 @@ class NumberWallVerifier(Verifier): """ def verify(self, data: Data, test_solution: str, **kwargs): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): # 提取答案网格 solution_grid = self.extract_answer(test_solution) diff --git a/verl/utils/reward_score/synlogic/numbrix_verifier.py b/verl/utils/reward_score/synlogic/numbrix_verifier.py index d7e9ea21..60fc4aab 100644 --- a/verl/utils/reward_score/synlogic/numbrix_verifier.py +++ b/verl/utils/reward_score/synlogic/numbrix_verifier.py @@ -3,7 +3,6 @@ import re import ast import numpy as np -from verl.utils.py_functional import timeout_limit class NumbrixVerifier(Verifier): """ @@ -12,7 +11,6 @@ class NumbrixVerifier(Verifier): """ def verify(self, data: Data, test_solution: str): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): # 提取答案网格 test_grid = self.extract_answer(test_solution) diff --git a/verl/utils/reward_score/synlogic/object_counting_verifier.py b/verl/utils/reward_score/synlogic/object_counting_verifier.py index 16c1d831..733d2333 100644 --- a/verl/utils/reward_score/synlogic/object_counting_verifier.py +++ b/verl/utils/reward_score/synlogic/object_counting_verifier.py @@ -1,7 +1,6 @@ import re from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END -from verl.utils.py_functional import timeout_limit class ObjectCountingVerifier(Verifier): @@ -10,7 +9,6 @@ class ObjectCountingVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): ground_truth = int(data.answer) parsed_answer = self.extract_answer(test_answer) diff --git 
a/verl/utils/reward_score/synlogic/object_properties_verifier.py b/verl/utils/reward_score/synlogic/object_properties_verifier.py index ca6a1f97..fba6aef2 100644 --- a/verl/utils/reward_score/synlogic/object_properties_verifier.py +++ b/verl/utils/reward_score/synlogic/object_properties_verifier.py @@ -1,7 +1,6 @@ import re from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END -from verl.utils.py_functional import timeout_limit class ObjectPropertiesVerifier(Verifier): @@ -10,7 +9,6 @@ class ObjectPropertiesVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): ground_truth = int(data.answer) parsed_answer_str = self.extract_answer(test_answer) diff --git a/verl/utils/reward_score/synlogic/operation_verifier.py b/verl/utils/reward_score/synlogic/operation_verifier.py index 3a2fa810..cdb0a83f 100644 --- a/verl/utils/reward_score/synlogic/operation_verifier.py +++ b/verl/utils/reward_score/synlogic/operation_verifier.py @@ -2,7 +2,6 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import math_verify -from verl.utils.py_functional import timeout_limit class OperationVerifier(Verifier): @@ -11,7 +10,6 @@ class OperationVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - @timeout_limit(seconds=25) def _verify_with_timeout(): ground_truth = math_verify.parse(data.answer, parsing_timeout=10) parsed_answer = math_verify.parse(test_answer, parsing_timeout=10) diff --git a/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py index bfc6b094..ec5c8ef3 100644 --- a/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py +++ b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py @@ -3,7 +3,6 @@ import re import json import ast -from verl.utils.py_functional import timeout_limit 
class SkyscraperPuzzleVerifier(Verifier): @@ -19,7 +18,6 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): # 获取游戏元数据 metadata = data.metadata diff --git a/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py index bd41bc8d..fbee5a7f 100644 --- a/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py +++ b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py @@ -2,7 +2,6 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import math_verify -from verl.utils.py_functional import timeout_limit class SpaceReasoningTreeVerifier(Verifier): """ @@ -10,7 +9,6 @@ class SpaceReasoningTreeVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): test_answer_extracted = self.extract_answer(test_answer) if test_answer_extracted is None: diff --git a/verl/utils/reward_score/synlogic/space_reasoning_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py index ed51d2a9..e2a14756 100644 --- a/verl/utils/reward_score/synlogic/space_reasoning_verifier.py +++ b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py @@ -2,7 +2,6 @@ from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END import math_verify -from verl.utils.py_functional import timeout_limit class SpaceReasoningVerifier(Verifier): @@ -11,7 +10,6 @@ class SpaceReasoningVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): test_answer_extracted = self.extract_answer(test_answer) if test_answer_extracted is None: diff --git a/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py 
b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py index 15f2695e..161aff88 100644 --- a/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py +++ b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py @@ -3,7 +3,6 @@ import re import json import ast -from verl.utils.py_functional import timeout_limit class StarPlacementPuzzleVerifier(Verifier): """ @@ -18,7 +17,6 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): star_coords = self.extract_answer(test_solution) # 获取游戏元数据 diff --git a/verl/utils/reward_score/synlogic/time_sequence_verifier.py b/verl/utils/reward_score/synlogic/time_sequence_verifier.py index 6133f437..4b37ac64 100644 --- a/verl/utils/reward_score/synlogic/time_sequence_verifier.py +++ b/verl/utils/reward_score/synlogic/time_sequence_verifier.py @@ -3,7 +3,6 @@ from .data import Data from .verifier import Verifier import re -from verl.utils.py_functional import timeout_limit class TimeSequenceVerifier(Verifier): """ @@ -18,7 +17,6 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): test_answer = self.extract_answer(test_solution) # 解析元数据 diff --git a/verl/utils/reward_score/synlogic/web_of_lies_verifier.py b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py index 751ea9c3..d0040661 100644 --- a/verl/utils/reward_score/synlogic/web_of_lies_verifier.py +++ b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py @@ -1,7 +1,6 @@ import re from .data import Data from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END -from verl.utils.py_functional import timeout_limit class WebOfLiesVerifier(Verifier): """ @@ -16,7 +15,6 @@ def verify(self, data: Data, test_solution: str): @return: 回答是否正确的布尔值 """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): test_answer = self.extract_answer(test_solution) 
# 获取预期答案和测试答案 diff --git a/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py index 3f2366b2..aff680ce 100644 --- a/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py +++ b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py @@ -1,7 +1,6 @@ import re from .data import Data from .verifier import Verifier -from verl.utils.py_functional import timeout_limit class WordSortingMistakeVerifier(Verifier): """ @@ -9,7 +8,6 @@ class WordSortingMistakeVerifier(Verifier): """ def verify(self, data: Data, test_answer: str): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): ground_truth = data.answer if data.answer is not None else "No" parsed_answer = self.extract_answer(test_answer) diff --git a/verl/utils/reward_score/synlogic/word_sorting_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_verifier.py index df3d6cf2..ed3ffe97 100644 --- a/verl/utils/reward_score/synlogic/word_sorting_verifier.py +++ b/verl/utils/reward_score/synlogic/word_sorting_verifier.py @@ -1,7 +1,6 @@ import re from .data import Data from .verifier import Verifier -from verl.utils.py_functional import timeout_limit class WordSortingVerifier(Verifier): """ @@ -14,7 +13,6 @@ def str2list(self, answer_str): def verify(self, data: Data, test_answer: str): try: - @timeout_limit(seconds=10) def _verify_with_timeout(): ground_truth = self.str2list(data.answer) parsed_answer = self.str2list(self.extract_answer(test_answer)) diff --git a/verl/utils/reward_score/synlogic/wordscapes_verifier.py b/verl/utils/reward_score/synlogic/wordscapes_verifier.py index 796f06c0..3e62888c 100644 --- a/verl/utils/reward_score/synlogic/wordscapes_verifier.py +++ b/verl/utils/reward_score/synlogic/wordscapes_verifier.py @@ -5,7 +5,6 @@ import json import re from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END -from verl.utils.py_functional import timeout_limit debug_mode = 
False @@ -28,7 +27,6 @@ def verify(self, data, test_solution: str): """ try: - @timeout_limit(seconds=10) def _verify_with_timeout(): extracted_answer = self.extract_answer(test_solution) if not extracted_answer: diff --git a/verl/utils/reward_score/zebra_puzzle.py b/verl/utils/reward_score/zebra_puzzle.py index 90ac730d..27aae643 100644 --- a/verl/utils/reward_score/zebra_puzzle.py +++ b/verl/utils/reward_score/zebra_puzzle.py @@ -4,8 +4,6 @@ import operator import json -from verl.utils.py_functional import timeout_limit - def extract_solution(solution_str): answer_pattern = r'(.*?)' @@ -53,7 +51,6 @@ def compute_accuracy(answer, ground_truth): return accuracy def compute_score(solution_str, ground_truth, extra_info: any = None, method='strict', timeout: float = 10.0): - @timeout_limit(seconds=timeout) def _compute_with_timeout(): predicted_arrangement = extract_solution(solution_str) @@ -68,9 +65,6 @@ def _compute_with_timeout(): try: score = _compute_with_timeout() - except TimeoutError: - print("Computation timed out in zebra_puzzle") - score = 0.0 except Exception as e: print(f"Error in compute_score in zebra_puzzle: {e}") score = 0.0 diff --git a/verl/workers/reward_manager/async_mp.py b/verl/workers/reward_manager/async_mp.py index 7150b66c..4577b11c 100644 --- a/verl/workers/reward_manager/async_mp.py +++ b/verl/workers/reward_manager/async_mp.py @@ -13,6 +13,7 @@ # limitations under the License. 
import asyncio +import gc from collections import defaultdict from concurrent.futures import ProcessPoolExecutor from functools import partial @@ -116,6 +117,9 @@ async def parallel_compute_score_async( actual_idx = start_idx + i results[actual_idx] = result + # Force garbage collection after each batch to prevent memory accumulation + gc.collect() + except Exception: for pid, proc in executor._processes.items(): try: @@ -329,6 +333,41 @@ def __call__(self, data: DataProto, return_dict: bool = False): # print(f"[DEBUG] Non-zero elements in reward_tensor: {(reward_tensor != 0).sum().item()}") # print(f"[DEBUG] Unique data sources processed: {list(already_print_data_sources.keys())}") + # Aggregate and print timing statistics by reward type + timing_by_type = defaultdict(lambda: {"count": 0, "total": 0.0, "max": 0.0}) + for key_list in reward_extra_info.get("_reward_time", []), reward_extra_info.get("_data_source", []): + pass # Just to check if keys exist + + if "_reward_time" in reward_extra_info and "_data_source" in reward_extra_info: + for reward_time, ds in zip( + reward_extra_info["_reward_time"], reward_extra_info["_data_source"], strict=False + ): + # Extract prefix (e.g., "codegen" from "codegen__deduped_leetcode2k") + if "__" in ds: + prefix = ds.split("__")[0] + elif "_" in ds: + prefix = ds.split("_")[0] + else: + prefix = ds + timing_by_type[prefix]["count"] += 1 + timing_by_type[prefix]["total"] += reward_time + timing_by_type[prefix]["max"] = max(timing_by_type[prefix]["max"], reward_time) + + # Print timing summary sorted by total time (descending) + print("\n=== REWARD TIMING BY TYPE ===") + for rtype, stats in sorted(timing_by_type.items(), key=lambda x: -x[1]["total"]): + avg = stats["total"] / max(stats["count"], 1) + print( + f" {rtype:20s}: {stats['count']:5d} samples, " + f"avg={avg*1000:8.2f}ms, max={stats['max']*1000:8.2f}ms, total={stats['total']:8.2f}s" + ) + print("=" * 50 + "\n") + + # Remove timing metadata keys before returning - they 
have inconsistent lengths + # when samples timeout/fail and would cause IndexError in batch reordering + reward_extra_info.pop("_reward_time", None) + reward_extra_info.pop("_data_source", None) + if return_dict: return { "reward_tensor": reward_tensor, From 2f0dc5bfdc951f9eb2afbf03119dd4617baa24bf Mon Sep 17 00:00:00 2001 From: Taylor Killian Date: Sat, 14 Feb 2026 00:01:28 +0000 Subject: [PATCH 20/20] sync --- scripts/train/k2p_hero_grpo_newData.sh | 37 +++++++++++++------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/scripts/train/k2p_hero_grpo_newData.sh b/scripts/train/k2p_hero_grpo_newData.sh index fcee39a9..d9a61fb1 100644 --- a/scripts/train/k2p_hero_grpo_newData.sh +++ b/scripts/train/k2p_hero_grpo_newData.sh @@ -1,7 +1,7 @@ #!/bin/bash -#SBATCH --job-name=grpo-stage2-k2pRL-easy50k-7domains -#SBATCH --nodes=64 -#SBATCH --ntasks=64 +#SBATCH --job-name=grpo-stage1-k2pRL-mainMathOnly +#SBATCH --nodes=32 +#SBATCH --ntasks=32 #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:8 #SBATCH --cpus-per-task=96 @@ -14,7 +14,7 @@ # commenting out... 
SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410] # =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-easy50k-7domains-415354" # Fill in the checkpoint directory name to resume from, otherwise from scratch +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty @@ -71,19 +71,19 @@ iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_pro # List of datasets to include (filename only) # Comment out lines to exclude specific datasets dataset_names=( - "codegen__deduped_leetcode2k_2.4k.parquet" - "codegen__deduped_livecodebench_599.parquet" - "codegen__deduped_primeintellect_9.6k.parquet" - "codegen__deduped_taco_11.1k.parquet" - "ifbench__fixed_85.6k.parquet" "math__combined_118.2k.part1.parquet" "math__combined_118.2k.part2.parquet" "omni_math_4.43k_dedup.parquet" - "stem__nemotron_13.3k.parquet" - "stem__web_31.7k.parquet" - "table__hitab_7.4k.parquet" - "table__multihier_2.9k.parquet" ) + # "stem__nemotron_13.3k.parquet" + # "stem__web_31.7k.parquet" + # "table__hitab_7.4k.parquet" + # "table__multihier_2.9k.parquet" +# "codegen__deduped_leetcode2k_2.4k.parquet" +# "codegen__deduped_livecodebench_599.parquet" +# "codegen__deduped_primeintellect_9.6k.parquet" +# "codegen__deduped_taco_11.1k.parquet" +# "ifbench__fixed_85.6k.parquet" # "simulation__codeio_fixed_12.1k.parquet" # "logic__arcagi1_297.parquet" # "logic__arcagi2_653.parquet" @@ -166,14 +166,15 @@ if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet # 
test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" # Full data mixture (uncomment to use) -test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', # =================== Model =================== # BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT -# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) # BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k 
-BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface # =================== Logging =================== WANDB_PROJECT=k2plus_rl @@ -228,7 +229,7 @@ clip_ratio_low=0.2 clip_ratio_high=0.28 max_prompt_length=$((1024 * 4)) -max_response_length=$((1024 * 64)) +max_response_length=$((1024 * 32)) enable_overlong_buffer=False overlong_buffer_len=$((1024 * 12)) overlong_penalty_factor=1.0 @@ -342,7 +343,7 @@ offload=True actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.model.enable_activation_offload=${offload} \ actor_rollout_ref.model.use_liger=True \ - reward_model.reward_manager=dapo \ + reward_model.reward_manager=async_multi_process \ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ reward_model.overlong_buffer.len=${overlong_buffer_len} \ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \