diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py
index cffb342e5..2e1441300 100644
--- a/recipe/dapo/dapo_ray_trainer.py
+++ b/recipe/dapo/dapo_ray_trainer.py
@@ -251,9 +251,10 @@ def fit(self):
batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
prompt_bsz = self.config.data.train_batch_size
- if num_prompt_in_batch < prompt_bsz:
+ max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
+ if num_prompt_in_batch < prompt_bsz and max_num_gen_batches > 1: # Added by Reasoning360 TWK NOTE: second condition is to account for when we have zero-variance filtering but are not dynamically growing the batch...
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
- max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
+ # max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f"{num_gen_batches=}. Keep generating...")
progress_bar.update(1)
@@ -267,9 +268,14 @@ def fit(self):
+ " You could also try set max_num_gen_batches=0 to enable endless trials."
)
else:
- # Align the batch
- traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
- batch = batch[:traj_bsz]
+ # Added by Reasoning360, need to account for when our batch is smaller due to zero-variance filtering
+ if num_prompt_in_batch >= prompt_bsz:
+ # Align the batch
+ traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
+ batch = batch[:traj_bsz]
+ else:
+ # TWK TODO!!!: RESCALE THIS SO THAT THE BATCH*N IS DIVISIBLE BY k_partitions (n_gpus...)
+ print(f"Final {num_prompt_in_batch=} < {prompt_bsz=} after {num_gen_batches=} generation batches. Proceeding with smaller batch...")
# === Updating ===
diff --git a/scripts/tools/serve_llm_as_verifier.sh b/scripts/tools/serve_llm_as_verifier.sh
index 0d9019385..1f4345b1e 100644
--- a/scripts/tools/serve_llm_as_verifier.sh
+++ b/scripts/tools/serve_llm_as_verifier.sh
@@ -1,6 +1,6 @@
#!/bin/bash
#SBATCH --job-name=server_llm_as_verifier
-#SBATCH --partition=main
+#SBATCH --partition=higherprio
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=64
diff --git a/scripts/tools/serve_math_llm_as_verifier.sh b/scripts/tools/serve_math_llm_as_verifier.sh
new file mode 100644
index 000000000..605988255
--- /dev/null
+++ b/scripts/tools/serve_math_llm_as_verifier.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+#SBATCH --job-name=server_math_llm_as_verifier
+#SBATCH --partition=higherprio
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --cpus-per-task=64
+#SBATCH --gres=gpu:8
+#SBATCH --time=720:00:00
+#SBATCH --output=slurm/serve_math_llm_as_verifier_%j.log
+#SBATCH --error=slurm/serve_math_llm_as_verifier_%j.log
+
+
+# (1) detect this node’s primary IP
+NODE_IP=$(hostname -I | awk '{print $1}')
+echo "Detected NODE_IP = $NODE_IP"
+
+# (2) export judge URL for downstream clients
+export MATH_LLM_JUDGE_URL="http://${NODE_IP}:8000"
+echo "MATH_LLM_JUDGE_URL=$MATH_LLM_JUDGE_URL"
+
+# (3) launch the vLLM server bound to that IP
+vllm serve openai/gpt-oss-120b --host "$NODE_IP" --data-parallel-size 8 --enable-expert-parallel
\ No newline at end of file
diff --git a/scripts/train/k2p/k2p_hero_grpo_stage2.sh b/scripts/train/k2p/k2p_hero_grpo_stage2.sh
new file mode 100644
index 000000000..2d37a7e2b
--- /dev/null
+++ b/scripts/train/k2p/k2p_hero_grpo_stage2.sh
@@ -0,0 +1,370 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-stage2-k2pRL-dataMix2
+#SBATCH --nodes=64
+#SBATCH --ntasks=64
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=higherprio
+#SBATCH --exclude=azure-uk-hpc-H200-instance-337
+
+# commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410]
+# job name: grpo-stage2-k2pRL-easy50k-7domains
+# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-dataMix2-417843" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+# export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+# export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-036:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-058:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+
+# =================== Cluster Environment ===================
+# export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/
+export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl//bin/
+export ROCR_VISIBLE_DEVICES=None
+export NCCL_TIMEOUT_SECONDS=4800000
+export RAY_memory_usage_threshold=0.95 # Increase Ray memory threshold before killing workers
+export OMPI_MCA_coll_hcoll_enable=0 \
+TORCH_NCCL_ENABLE_MONITORING=0 \
+CUDA_DEVICE_ORDER=PCI_BUS_ID \
+NCCL_SOCKET_IFNAME=eth0 \
+UCX_TLS=rc \
+UCX_NET_DEVICES=mlx5_ib0:1 \
+NCCL_DEBUG=WARN \
+NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \
+NCCL_IB_PCI_RELAXED_ORDERING=1 \
+NCCL_IB_QPS_PER_CONNECTION=4 \
+NCCL_IGNORE_CPU_AFFINITY=1 \
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \
+NCCL_PXN_DISABLE=1 \
+NCCL_MIN_NCHANNELS=32 \
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
+SHARP_COLL_ENABLE_SAT=1 \
+SHARP_COLL_LOG_LEVEL=3 \
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+NCCL_COLLNET_ENABLE=1 \
+NCCL_NVLS_ENABLE=0
+
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+
+# Training Data Configuration
+DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_2"
+train_file_list=()
+id_val_file_list=()
+
+iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet"
+
+# List of datasets to include (filename only)
+# Comment out lines to exclude specific datasets
+dataset_names=(
+ "codegen__deduped_leetcode2k_2.4k.parquet"
+ "codegen__deduped_livecodebench_599.parquet"
+ "codegen__deduped_primeintellect_9.6k.parquet"
+ "codegen__deduped_taco_11.1k.parquet"
+ "ifbench__fixed_85.6k.parquet"
+ "simulation__codeio_fixed_12.1k.parquet"
+ "logic__arcagi1_297.parquet"
+ "logic__arcagi2_653.parquet"
+ "logic__barc_3.4k.parquet"
+ "logic__graph_logical_dataset_1.4k.parquet"
+ "logic__ordering_puzzle_dataset_2.9k.parquet"
+ "logic__reasoning_gym_40.6k.parquet"
+ "logic__synlogic_12.1k.parquet"
+ "logic__zebra_puzzle_dataset_5.0k.parquet"
+ "math__combined_118.2k.part1.parquet"
+ "math__combined_118.2k.part2.parquet"
+ "omni_math_4.43k.parquet"
+ "stem__nemotron_13.3k.parquet"
+ "stem__web_31.7k.parquet"
+ "table__hitab_7.4k.parquet"
+ "table__multihier_2.9k.parquet"
+)
+
+echo "Collecting training files from ${DATA_MIX_DIR}..."
+
+# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions"
+for dataset in "${dataset_names[@]}"; do
+ for subdir in "main_questions" "131k_context_questions" "impossible_questions"; do
+ file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+ if [ -f "$file_path" ]; then
+ echo "Adding: $file_path"
+ train_file_list+=("'$file_path'")
+ fi
+ done
+done
+
+# for dataset in "${dataset_names[@]}"; do
+# for subdir in "131k_context_questions"; do
+# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+# if [ -f "$file_path" ]; then
+# echo "Adding: $file_path"
+# id_val_file_list+=("'$file_path'")
+# fi
+# done
+# done
+# id_val_file_list+=("'$iq400_path'")
+
+# Join with comma to form Python list string
+IFS=,
+train_files="[${train_file_list[*]}]"
+# id_val_files="[${id_val_file_list[*]}]"
+unset IFS
+
+echo "Total training files found: ${#train_file_list[@]}"
+# echo "Total ID validation files found: ${#id_val_file_list[@]}"
+
+# Test Data Configuration
+TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (test)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet
+synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet
+# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet
+
+# Stem (test)
+nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet
+
+# Instruction follow (test)
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet
+if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet
+
+# Focused data mixture (math, code, stem)
+# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Full data mixture (uncomment to use)
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', '${reasoninggym_test_path}'
+
+
+# =================== Model ===================
+# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k)
+BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k
+# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/}
+# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491"
+
+# If RESUME_CKPT_DIR is not empty, resume from the checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 64))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=256 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Disable sequence parallelism
+gen_tp=4
+gen_max_num_seqs=1024
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=True \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=${offload} \
+ actor_rollout_ref.model.use_liger=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ +reward_model.reward_kwargs.num_processes=64 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=False \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=5 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=5 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=3
+ # trainer.log_val_generations=50
diff --git a/scripts/train/k2p/k2p_hero_grpo_stage2_fixed.sh b/scripts/train/k2p/k2p_hero_grpo_stage2_fixed.sh
new file mode 100644
index 000000000..4d0070e04
--- /dev/null
+++ b/scripts/train/k2p/k2p_hero_grpo_stage2_fixed.sh
@@ -0,0 +1,369 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-stage2-k2pRL-dataMix2-fixed
+#SBATCH --nodes=64
+#SBATCH --ntasks=64
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=higherprio
+#SBATCH --exclude=azure-uk-hpc-H200-instance-337
+
+# commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410]
+# job name: grpo-stage2-k2pRL-easy50k-7domains
+# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="grpo-stage2-k2pRL-dataMix2-417843" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+# export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+# export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-117:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-233:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+
+# =================== Cluster Environment ===================
+# export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/
+export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl//bin/
+export ROCR_VISIBLE_DEVICES=None
+export NCCL_TIMEOUT_SECONDS=4800000
+export OMPI_MCA_coll_hcoll_enable=0 \
+TORCH_NCCL_ENABLE_MONITORING=0 \
+CUDA_DEVICE_ORDER=PCI_BUS_ID \
+NCCL_SOCKET_IFNAME=eth0 \
+UCX_TLS=rc \
+UCX_NET_DEVICES=mlx5_ib0:1 \
+NCCL_DEBUG=WARN \
+NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \
+NCCL_IB_PCI_RELAXED_ORDERING=1 \
+NCCL_IB_QPS_PER_CONNECTION=4 \
+NCCL_IGNORE_CPU_AFFINITY=1 \
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \
+NCCL_PXN_DISABLE=1 \
+NCCL_MIN_NCHANNELS=32 \
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
+SHARP_COLL_ENABLE_SAT=1 \
+SHARP_COLL_LOG_LEVEL=3 \
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+NCCL_COLLNET_ENABLE=1 \
+NCCL_NVLS_ENABLE=0
+
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+
+# Training Data Configuration
+DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_2"
+train_file_list=()
+id_val_file_list=()
+
+iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet"
+
+# List of datasets to include (filename only)
+# Comment out lines to exclude specific datasets
+dataset_names=(
+ "codegen__deduped_leetcode2k_2.4k.parquet"
+ "codegen__deduped_livecodebench_599.parquet"
+ "codegen__deduped_primeintellect_9.6k.parquet"
+ "codegen__deduped_taco_11.1k.parquet"
+ "ifbench__fixed_85.6k.parquet"
+ "simulation__codeio_fixed_12.1k.parquet"
+ "logic__arcagi1_297.parquet"
+ "logic__arcagi2_653.parquet"
+ "logic__barc_3.4k.parquet"
+ "logic__graph_logical_dataset_1.4k.parquet"
+ "logic__ordering_puzzle_dataset_2.9k.parquet"
+ "logic__reasoning_gym_40.6k.parquet"
+ "logic__synlogic_12.1k.parquet"
+ "logic__zebra_puzzle_dataset_5.0k.parquet"
+ "math__combined_118.2k.part1.parquet"
+ "math__combined_118.2k.part2.parquet"
+ "omni_math_4.43k.parquet"
+ "stem__nemotron_13.3k.parquet"
+ "stem__web_31.7k.parquet"
+ "table__hitab_7.4k.parquet"
+ "table__multihier_2.9k.parquet"
+)
+
+echo "Collecting training files from ${DATA_MIX_DIR}..."
+
+# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions"
+for dataset in "${dataset_names[@]}"; do
+ for subdir in "main_questions" "131k_context_questions" "impossible_questions"; do
+ file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+ if [ -f "$file_path" ]; then
+ echo "Adding: $file_path"
+ train_file_list+=("'$file_path'")
+ fi
+ done
+done
+
+# for dataset in "${dataset_names[@]}"; do
+# for subdir in "131k_context_questions"; do
+# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+# if [ -f "$file_path" ]; then
+# echo "Adding: $file_path"
+# id_val_file_list+=("'$file_path'")
+# fi
+# done
+# done
+# id_val_file_list+=("'$iq400_path'")
+
+# Join with comma to form Python list string
+IFS=,
+train_files="[${train_file_list[*]}]"
+# id_val_files="[${id_val_file_list[*]}]"
+unset IFS
+
+echo "Total training files found: ${#train_file_list[@]}"
+# echo "Total ID validation files found: ${#id_val_file_list[@]}"
+
+# Test Data Configuration
+TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (test)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet
+synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet
+# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet
+
+# Stem (test)
+nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet
+
+# Instruction follow (test)
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet
+if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet
+
+# Focused data mixture (math, code, stem)
+# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Full data mixture (uncomment to use)
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', '${reasoninggym_test_path}'
+
+
+# =================== Model ===================
+# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k)
+BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k
+# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/}
+# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491"
+
+# If RESUME_CKPT_DIR is not empty, resume from the checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 64))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=256 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Disable sequence parallelism
+gen_tp=4
+gen_max_num_seqs=1024
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=48000 \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=${offload} \
+ actor_rollout_ref.model.use_liger=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ +reward_model.reward_kwargs.num_processes=64 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=False \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=5 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=5 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=3
+ # trainer.log_val_generations=50
diff --git a/scripts/train/k2p/k2p_hero_grpo_stage2_fixed_n64.sh b/scripts/train/k2p/k2p_hero_grpo_stage2_fixed_n64.sh
new file mode 100644
index 000000000..d6524857a
--- /dev/null
+++ b/scripts/train/k2p/k2p_hero_grpo_stage2_fixed_n64.sh
@@ -0,0 +1,369 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-stage2-k2pRL-dataMix2-n64-fixed
+#SBATCH --nodes=64
+#SBATCH --ntasks=64
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=higherprio
+#SBATCH --exclude=azure-uk-hpc-H200-instance-337
+
+# commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410]
+# job name: grpo-stage2-k2pRL-easy50k-7domains
+# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+# export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+# export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-036:8000,http://azure-uk-hpc-H200-instance-061:8000,http://azure-uk-hpc-H200-instance-062:8000,http://azure-uk-hpc-H200-instance-175:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-058:8000,http://azure-uk-hpc-H200-instance-176:8000,http://azure-uk-hpc-H200-instance-387:8000,http://azure-uk-hpc-H200-instance-388:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+
+# =================== Cluster Environment ===================
+# export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/
+export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl//bin/
+export ROCR_VISIBLE_DEVICES=None
+export NCCL_TIMEOUT_SECONDS=4800000
+export OMPI_MCA_coll_hcoll_enable=0 \
+TORCH_NCCL_ENABLE_MONITORING=0 \
+CUDA_DEVICE_ORDER=PCI_BUS_ID \
+NCCL_SOCKET_IFNAME=eth0 \
+UCX_TLS=rc \
+UCX_NET_DEVICES=mlx5_ib0:1 \
+NCCL_DEBUG=WARN \
+NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \
+NCCL_IB_PCI_RELAXED_ORDERING=1 \
+NCCL_IB_QPS_PER_CONNECTION=4 \
+NCCL_IGNORE_CPU_AFFINITY=1 \
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \
+NCCL_PXN_DISABLE=1 \
+NCCL_MIN_NCHANNELS=32 \
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
+SHARP_COLL_ENABLE_SAT=1 \
+SHARP_COLL_LOG_LEVEL=3 \
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+NCCL_COLLNET_ENABLE=1 \
+NCCL_NVLS_ENABLE=0
+
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+
+# Training Data Configuration
+DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_2"
+train_file_list=()
+id_val_file_list=()
+
+iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet"
+
+# List of datasets to include (filename only)
+# Comment out lines to exclude specific datasets
+dataset_names=(
+ "codegen__deduped_leetcode2k_2.4k.parquet"
+ "codegen__deduped_livecodebench_599.parquet"
+ "codegen__deduped_primeintellect_9.6k.parquet"
+ "codegen__deduped_taco_11.1k.parquet"
+ "ifbench__fixed_85.6k.parquet"
+ "simulation__codeio_fixed_12.1k.parquet"
+ "logic__arcagi1_297.parquet"
+ "logic__arcagi2_653.parquet"
+ "logic__barc_3.4k.parquet"
+ "logic__graph_logical_dataset_1.4k.parquet"
+ "logic__ordering_puzzle_dataset_2.9k.parquet"
+ "logic__reasoning_gym_40.6k.parquet"
+ "logic__synlogic_12.1k.parquet"
+ "logic__zebra_puzzle_dataset_5.0k.parquet"
+ "math__combined_118.2k.part1.parquet"
+ "math__combined_118.2k.part2.parquet"
+ "omni_math_4.43k.parquet"
+ "stem__nemotron_13.3k.parquet"
+ "stem__web_31.7k.parquet"
+ "table__hitab_7.4k.parquet"
+ "table__multihier_2.9k.parquet"
+)
+
+echo "Collecting training files from ${DATA_MIX_DIR}..."
+
+# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions"
+for dataset in "${dataset_names[@]}"; do
+ for subdir in "main_questions" "131k_context_questions" "impossible_questions"; do
+ file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+ if [ -f "$file_path" ]; then
+ echo "Adding: $file_path"
+ train_file_list+=("'$file_path'")
+ fi
+ done
+done
+
+# for dataset in "${dataset_names[@]}"; do
+# for subdir in "131k_context_questions"; do
+# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+# if [ -f "$file_path" ]; then
+# echo "Adding: $file_path"
+# id_val_file_list+=("'$file_path'")
+# fi
+# done
+# done
+# id_val_file_list+=("'$iq400_path'")
+
+# Join with comma to form Python list string
+IFS=,
+train_files="[${train_file_list[*]}]"
+# id_val_files="[${id_val_file_list[*]}]"
+unset IFS
+
+echo "Total training files found: ${#train_file_list[@]}"
+# echo "Total ID validation files found: ${#id_val_file_list[@]}"
+
+# Test Data Configuration
+TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (test)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet
+synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet
+# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet
+
+# Stem (test)
+nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet
+
+# Instruction follow (test)
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet
+if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet
+
+# Focused data mixture (math, code, stem)
+# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Full data mixture (uncomment to use)
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}', '${reasoninggym_test_path}'
+
+
+# =================== Model ===================
+# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k)
+BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k
+# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/}
+# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491"
+
+# If RESUME_CKPT_DIR is not empty, resume from the checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note: we borrowed the config format from DAPO but disabled all DAPO features here to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 64))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=64
+train_prompt_mini_bsz=256 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Ulysses sequence parallel size (16 = SP enabled across 16 GPUs; set to 1 to disable)
+gen_tp=4
+gen_max_num_seqs=1024
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=60000 \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=${offload} \
+ actor_rollout_ref.model.use_liger=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ +reward_model.reward_kwargs.num_processes=128 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=5 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=5 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=3
+ # trainer.log_val_generations=50
diff --git a/scripts/train/k2p/k2p_hero_grpo_stage2_profiling.sh b/scripts/train/k2p/k2p_hero_grpo_stage2_profiling.sh
new file mode 100644
index 000000000..5586bf1ad
--- /dev/null
+++ b/scripts/train/k2p/k2p_hero_grpo_stage2_profiling.sh
@@ -0,0 +1,377 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-stage2-k2pRL-dataMix2-profiling
+#SBATCH --nodes=64
+#SBATCH --ntasks=64
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=higherprio
+#SBATCH --exclude=azure-uk-hpc-H200-instance-337
+
+# commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410]
+# job name: grpo-stage2-k2pRL-easy50k-7domains
+# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+# export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+# export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-036:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-058:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+
+# =================== Cluster Environment ===================
+export CONDA_BIN_PATH=/lustrefs/users/varad.pimpalkhute/anaconda3/envs/sync-rl-v5/bin/
+export ROCR_VISIBLE_DEVICES=None
+export NCCL_TIMEOUT_SECONDS=4800000
+export RAY_memory_usage_threshold=0.95 # Increase Ray memory threshold before killing workers
+export OMPI_MCA_coll_hcoll_enable=0 \
+TORCH_NCCL_ENABLE_MONITORING=0 \
+CUDA_DEVICE_ORDER=PCI_BUS_ID \
+NCCL_SOCKET_IFNAME=eth0 \
+UCX_TLS=rc \
+UCX_NET_DEVICES=mlx5_ib0:1 \
+NCCL_DEBUG=WARN \
+NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \
+NCCL_IB_PCI_RELAXED_ORDERING=1 \
+NCCL_IB_QPS_PER_CONNECTION=4 \
+NCCL_IGNORE_CPU_AFFINITY=1 \
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \
+NCCL_PXN_DISABLE=1 \
+NCCL_MIN_NCHANNELS=32 \
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
+SHARP_COLL_ENABLE_SAT=1 \
+SHARP_COLL_LOG_LEVEL=3 \
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+NCCL_COLLNET_ENABLE=1 \
+NCCL_NVLS_ENABLE=0
+
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+
+# Training Data Configuration
+DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_2"
+train_file_list=()
+id_val_file_list=()
+
+iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet"
+
+# List of datasets to include (filename only)
+# Comment out lines to exclude specific datasets
+dataset_names=(
+ "codegen__deduped_leetcode2k_2.4k.parquet"
+ "codegen__deduped_livecodebench_599.parquet"
+ "codegen__deduped_primeintellect_9.6k.parquet"
+ "codegen__deduped_taco_11.1k.parquet"
+ "ifbench__fixed_85.6k.parquet"
+ "simulation__codeio_fixed_12.1k.parquet"
+ "logic__arcagi1_297.parquet"
+ "logic__arcagi2_653.parquet"
+ "logic__barc_3.4k.parquet"
+ "logic__graph_logical_dataset_1.4k.parquet"
+ "logic__ordering_puzzle_dataset_2.9k.parquet"
+ "logic__reasoning_gym_40.6k.parquet"
+ "logic__synlogic_12.1k.parquet"
+ "logic__zebra_puzzle_dataset_5.0k.parquet"
+ "math__combined_118.2k.part1.parquet"
+ "math__combined_118.2k.part2.parquet"
+ "omni_math_4.43k.parquet"
+ "stem__nemotron_13.3k.parquet"
+ "stem__web_31.7k.parquet"
+ "table__hitab_7.4k.parquet"
+ "table__multihier_2.9k.parquet"
+)
+
+echo "Collecting training files from ${DATA_MIX_DIR}..."
+
+# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions"
+for dataset in "${dataset_names[@]}"; do
+ for subdir in "main_questions" "131k_context_questions" "impossible_questions"; do
+ file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+ if [ -f "$file_path" ]; then
+ echo "Adding: $file_path"
+ train_file_list+=("'$file_path'")
+ fi
+ done
+done
+
+# for dataset in "${dataset_names[@]}"; do
+# for subdir in "131k_context_questions"; do
+# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+# if [ -f "$file_path" ]; then
+# echo "Adding: $file_path"
+# id_val_file_list+=("'$file_path'")
+# fi
+# done
+# done
+# id_val_file_list+=("'$iq400_path'")
+
+# Join with comma to form Python list string
+IFS=,
+train_files="[${train_file_list[*]}]"
+# id_val_files="[${id_val_file_list[*]}]"
+unset IFS
+
+echo "Total training files found: ${#train_file_list[@]}"
+# echo "Total ID validation files found: ${#id_val_file_list[@]}"
+
+# Test Data Configuration
+TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (test)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet
+synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet
+# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet
+
+# Stem (test)
+nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet
+
+# Instruction follow (test)
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet
+if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet
+
+# Focused data mixture (math, code, stem)
+# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Full data mixture (uncomment to use)
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}',
+
+
+# =================== Model ===================
+# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k)
+BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k
+# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/}
+# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491"
+
+# If RESUME_CKPT_DIR is not empty, resume from the checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note: we borrowed the config format from DAPO but disabled all DAPO features here to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 64))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=256 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Ulysses sequence-parallel size (16-way; set to 1 to disable sequence parallelism)
+gen_tp=4
+gen_max_num_seqs=256 # REDUCED from 1024 to fix OOM with long sequences
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=True \
+ actor_rollout_ref.rollout.enable_prefix_caching=False \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=${offload} \
+ actor_rollout_ref.model.use_liger=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ +reward_model.reward_kwargs.num_processes=48 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=False \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=5 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=5 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=3 \
+ global_profiler.tool=torch \
+ global_profiler.steps=[1,2,3,4,5,6,7,8,9] \
+ actor_rollout_ref.actor.profiler.enable=True \
+ actor_rollout_ref.actor.profiler.all_ranks=True \
+ actor_rollout_ref.rollout.profiler.enable=True \
+ actor_rollout_ref.rollout.profiler.all_ranks=True
+
+
+ # trainer.log_val_generations=50
diff --git a/scripts/train/k2p_hero_cispo.sh b/scripts/train/k2p_hero_cispo.sh
new file mode 100644
index 000000000..25e4b338a
--- /dev/null
+++ b/scripts/train/k2p_hero_cispo.sh
@@ -0,0 +1,327 @@
+#!/bin/bash
+#SBATCH --job-name=cispo-focused-k2p-finalInstruct-temp1.0-wOmni-fix2
+#SBATCH --nodes=32
+#SBATCH --ntasks=32
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=main
+#SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+
+# =================== Cluster Environment ===================
+export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/
+export ROCR_VISIBLE_DEVICES=None
+export NCCL_TIMEOUT_SECONDS=4800000
+export OMPI_MCA_coll_hcoll_enable=0 \
+TORCH_NCCL_ENABLE_MONITORING=0 \
+CUDA_DEVICE_ORDER=PCI_BUS_ID \
+NCCL_SOCKET_IFNAME=eth0 \
+UCX_TLS=rc \
+UCX_NET_DEVICES=mlx5_ib0:1 \
+NCCL_DEBUG=WARN \
+NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \
+NCCL_IB_PCI_RELAXED_ORDERING=1 \
+NCCL_IB_QPS_PER_CONNECTION=4 \
+NCCL_IGNORE_CPU_AFFINITY=1 \
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \
+NCCL_PXN_DISABLE=1 \
+NCCL_MIN_NCHANNELS=32 \
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
+SHARP_COLL_ENABLE_SAT=1 \
+SHARP_COLL_LOG_LEVEL=3 \
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+NCCL_COLLNET_ENABLE=1
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+SHARED_DATA_PATH=/lustrefs/users/haonan.li/data/k2
+MATH_DATA_PATH=/lustrefs/users/zhuojun.cheng/vpim/guru_data/train/postprocessed_dedup_am_semantic_filtered_0.05_0.94_thresh_ratio0.5_sample1.0_balanced_step2
+TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_1_datamix_6
+TEST_DATA_DIR=${SHARED_DATA_PATH}/test_12k_len
+
+# Math (train)
+math_train_path1=${MATH_DATA_PATH}/math__combined_118.2k.part1_scored.parquet
+math_train_path2=${MATH_DATA_PATH}/math__combined_118.2k.part2_scored.parquet
+# math_train_path1=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet
+# math_train_path2=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (train)
+leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet
+livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet
+primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet
+taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (train)
+arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet
+arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet
+barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet
+graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet
+ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet
+zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet
+reasoninggym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet
+synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet
+# Logic (test)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet
+synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet
+# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+
+
+# Simulation (train)
+codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet
+# Simulation (test)
+# codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet
+
+# Table (train)
+hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet
+multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet
+
+# Stem (train)
+webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet
+nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet
+# Stem (test)
+nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet
+
+# Instruction follow (train)
+if_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet
+
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet
+if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet
+
+# Focused data mixture (math, code, stem)
+train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']"
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Full data mixture (uncomment to use)
+# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${reasoninggym_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}','${nemotron_train_path}','${if_train_path}']" # '${synlogic_train_path}',
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # '${synlogic_test_path}',
+
+
+# =================== Model ===================
+# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT
+BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k)
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/}
+
+# If RESUME_CKPT_DIR_NAME is not empty, resume from that checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note: we borrowed the config format from DAPO but disabled all DAPO features here to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=2.0
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 32))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_mode="cispo" # Default is 'vanilla' which is equivalent to PPO;
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=256 # model grad update batchsize
+
+# Algorithm
+temperature=1.0
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Reduced from 32 to reduce memory pressure
+gen_tp=4
+gen_max_num_seqs=1024 # vLLM max concurrent sequences; reduce if memory pressure occurs
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=True \
+ actor_rollout_ref.model.use_liger=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=False \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=10 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=5 \
+ trainer.log_val_generations=50 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=3
\ No newline at end of file
diff --git a/scripts/train/k2p_hero_cispo_newData.sh b/scripts/train/k2p_hero_cispo_newData.sh
new file mode 100644
index 000000000..8ffd6d2f5
--- /dev/null
+++ b/scripts/train/k2p_hero_cispo_newData.sh
@@ -0,0 +1,362 @@
+#!/bin/bash
+#SBATCH --job-name=cispo-k2p-newFiltered-64k-mainQs-finalInstruct
+#SBATCH --nodes=64
+#SBATCH --ntasks=64
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=higherprio
+#SBATCH --exclude=azure-uk-hpc-H200-instance-[347-410]
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="cispo-k2p-newFiltered-64k-mainQs-finalInstruct-406746" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-286:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+
+# =================== Cluster Environment ===================
+export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/
+export ROCR_VISIBLE_DEVICES=None
+export NCCL_TIMEOUT_SECONDS=4800000
+export OMPI_MCA_coll_hcoll_enable=0 \
+TORCH_NCCL_ENABLE_MONITORING=0 \
+CUDA_DEVICE_ORDER=PCI_BUS_ID \
+NCCL_SOCKET_IFNAME=eth0 \
+TP_SOCKET_IFNAME=eth0 \
+GLOO_SOCKET_IFNAME=eth0 \
+UCX_TLS=rc \
+UCX_NET_DEVICES=mlx5_ib0:1 \
+NCCL_DEBUG=WARN \
+NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \
+NCCL_IB_PCI_RELAXED_ORDERING=1 \
+NCCL_IB_QPS_PER_CONNECTION=4 \
+NCCL_IGNORE_CPU_AFFINITY=1 \
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \
+NCCL_PXN_DISABLE=1 \
+NCCL_MIN_NCHANNELS=32 \
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
+SHARP_COLL_ENABLE_SAT=1 \
+SHARP_COLL_LOG_LEVEL=3 \
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+NCCL_COLLNET_ENABLE=1
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+
+# Training Data Configuration
+DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1"
+train_file_list=()
+
+iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet"
+
+# List of datasets to include (filename only)
+# Comment out lines to exclude specific datasets
+dataset_names=(
+ "codegen__deduped_leetcode2k_2.4k.parquet"
+ "codegen__deduped_livecodebench_599.parquet"
+ "codegen__deduped_primeintellect_9.6k.parquet"
+ "codegen__deduped_taco_11.1k.parquet"
+ "ifbench__fixed_85.6k.parquet"
+ "logic__arcagi1_297.parquet"
+ "logic__arcagi2_653.parquet"
+ "logic__barc_3.4k.parquet"
+ "logic__graph_logical_dataset_1.4k.parquet"
+ "logic__ordering_puzzle_dataset_2.9k.parquet"
+ "logic__reasoning_gym_40.6k.parquet"
+ "logic__synlogic_12.1k.parquet"
+ "logic__zebra_puzzle_dataset_5.0k.parquet"
+ "math__combined_118.2k.part1.parquet"
+ "math__combined_118.2k.part2.parquet"
+ "omni_math_4.43k_dedup.parquet"
+ "simulation__codeio_fixed_12.1k.parquet"
+ "stem__nemotron_13.3k.parquet"
+ "stem__web_31.7k.parquet"
+ "table__hitab_7.4k.parquet"
+ "table__multihier_2.9k.parquet"
+)
+
+echo "Collecting training files from ${DATA_MIX_DIR}..."
+
+# Search for each dataset under the selected subdirectory (currently only "main_questions"; other options: "impossible_questions", "131k_context_questions")
+for dataset in "${dataset_names[@]}"; do
+ for subdir in "main_questions"; do
+ file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+ if [ -f "$file_path" ]; then
+ echo "Adding: $file_path"
+ train_file_list+=("'$file_path'")
+ fi
+ done
+done
+
+# for dataset in "${dataset_names[@]}"; do
+# for subdir in "131k_context_questions"; do
+# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+# if [ -f "$file_path" ]; then
+# echo "Adding: $file_path"
+# id_val_file_list+=("'$file_path'")
+# fi
+# done
+# done
+# id_val_file_list+=("'$iq400_path'")
+
+# Join with comma to form Python list string
+IFS=,
+train_files="[${train_file_list[*]}]"
+# id_val_files="[${id_val_file_list[*]}]"
+unset IFS
+
+echo "Total training files found: ${#train_file_list[@]}"
+# echo "Total ID validation files found: ${#id_val_file_list[@]}"
+
+# Test Data Configuration
+TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (test)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet
+synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet
+# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet
+
+# Stem (test)
+nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet
+
+# Instruction follow (test)
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet
+if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet
+
+# Focused data mixture (math, code, stem)
+# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Full data mixture (uncomment to use)
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${synlogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}'
+
+
+# =================== Model ===================
+# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT
+BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k)
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k
+# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/cispo-k2p-newFiltered-32k-mainQs-finalInstruct-406513/global_step_100/actor/huggingface
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/}
+# WANDB_EXPERIMENT_NAME="cispo-k2p-newFiltered-32k-mainQs-finalInstruct-406513"
+
+# If RESUME_CKPT_DIR_NAME is not empty, resume from that checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note: we borrow the config format from DAPO but disable all DAPO features here to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=2.0
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 64))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_mode="cispo" #Default is 'vanilla' which is equivalent to PPO
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=128 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=128 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Reduced from 32 to reduce memory pressure
+gen_tp=4
+gen_max_num_seqs=1024 # Cap on concurrent sequences during generation, kept moderate to limit memory pressure
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\
+ actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=True \
+ actor_rollout_ref.model.use_liger=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=False \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=10 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=5 \
+ trainer.log_val_generations=50 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=3
+ # data.id_val_files="$id_val_files" \
\ No newline at end of file
diff --git a/scripts/train/k2p_hero_grpo.sh b/scripts/train/k2p_hero_grpo.sh
new file mode 100644
index 000000000..5908854e2
--- /dev/null
+++ b/scripts/train/k2p_hero_grpo.sh
@@ -0,0 +1,326 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-k2p-32k264k-stage2-focused
+#SBATCH --nodes=64
+#SBATCH --ntasks=64
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=main
+#SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="grpo-k2p-32k264k-stage2-focused-404083" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+
+# =================== Cluster Environment ===================
+export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/
+export ROCR_VISIBLE_DEVICES=None
+export NCCL_TIMEOUT_SECONDS=4800000
+export OMPI_MCA_coll_hcoll_enable=0 \
+TORCH_NCCL_ENABLE_MONITORING=0 \
+CUDA_DEVICE_ORDER=PCI_BUS_ID \
+NCCL_SOCKET_IFNAME=eth0 \
+UCX_TLS=rc \
+UCX_NET_DEVICES=mlx5_ib0:1 \
+NCCL_DEBUG=WARN \
+NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \
+NCCL_IB_PCI_RELAXED_ORDERING=1 \
+NCCL_IB_QPS_PER_CONNECTION=4 \
+NCCL_IGNORE_CPU_AFFINITY=1 \
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \
+NCCL_PXN_DISABLE=1 \
+NCCL_MIN_NCHANNELS=32 \
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
+SHARP_COLL_ENABLE_SAT=1 \
+SHARP_COLL_LOG_LEVEL=3 \
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+NCCL_COLLNET_ENABLE=1
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+SHARED_DATA_PATH=/lustrefs/users/haonan.li/data/k2
+MATH_DATA_PATH=/lustrefs/users/zhuojun.cheng/vpim/guru_data/train/postprocessed_dedup_am_semantic_filtered_0.05_0.94_thresh_ratio0.5_sample1.0_balanced_step2
+TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_1_datamix_6
+TEST_DATA_DIR=${SHARED_DATA_PATH}/test_12k_len
+
+# Math (train)
+math_train_path1=${MATH_DATA_PATH}/math__combined_118.2k.part1_scored.parquet
+math_train_path2=${MATH_DATA_PATH}/math__combined_118.2k.part2_scored.parquet
+# math_train_path1=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet
+# math_train_path2=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (train)
+leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet
+livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet
+primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet
+taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (train)
+arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet
+arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet
+barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet
+graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet
+ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet
+zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet
+reasoninggym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet
+synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet
+# Logic (test)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet
+synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet
+# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+
+
+# Simulation (train)
+codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet
+# Simulation (test)
+# codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet
+
+# Table (train)
+hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet
+multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet
+
+# Stem (train)
+webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet
+nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet
+# Stem (test)
+nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet
+
+# Instruction follow (train)
+if_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet
+
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet
+if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet
+
+# Focused data mixture (math, code, stem)
+train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']"
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Full data mixture (uncomment to use)
+# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${reasoninggym_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}','${nemotron_train_path}','${if_train_path}']" # '${synlogic_train_path}',
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # '${synlogic_test_path}',
+
+
+# =================== Model ===================
+# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT
+BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k)
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k
+# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906/global_step_300/actor/huggingface
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/}
+
+# If RESUME_CKPT_DIR_NAME is not empty, resume from the checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note: we borrow the config format from DAPO but disable all DAPO features here to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 64))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=256 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Reduced from 32 to reduce memory pressure
+gen_tp=4
+gen_max_num_seqs=1024 # Cap on concurrent sequences during generation, kept moderate to limit memory pressure
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\
+ actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=True \
+ actor_rollout_ref.model.use_liger=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=False \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=10 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=5 \
+ trainer.log_val_generations=50 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=3
\ No newline at end of file
diff --git a/scripts/train/k2p_hero_grpo_newData.sh b/scripts/train/k2p_hero_grpo_newData.sh
new file mode 100644
index 000000000..b0a4f1b5e
--- /dev/null
+++ b/scripts/train/k2p_hero_grpo_newData.sh
@@ -0,0 +1,365 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-stage1-k2pRL-mainMathOnly
+#SBATCH --nodes=32
+#SBATCH --ntasks=32
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=higherprio
+
+# commenting out... SBATCH --exclude=azure-uk-hpc-H200-instance-[043-060,249,347-410]
+# job name: grpo-stage2-k2pRL-easy50k-7domains
+# job name: grpo-k2p-newFiltered-64k-fullData-finalInstruct
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-284:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+
+# =================== Cluster Environment ===================
+export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/
+export ROCR_VISIBLE_DEVICES=None
+export NCCL_TIMEOUT_SECONDS=4800000
+export OMPI_MCA_coll_hcoll_enable=0 \
+TORCH_NCCL_ENABLE_MONITORING=0 \
+CUDA_DEVICE_ORDER=PCI_BUS_ID \
+NCCL_SOCKET_IFNAME=eth0 \
+UCX_TLS=rc \
+UCX_NET_DEVICES=mlx5_ib0:1 \
+NCCL_DEBUG=WARN \
+NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \
+NCCL_IB_PCI_RELAXED_ORDERING=1 \
+NCCL_IB_QPS_PER_CONNECTION=4 \
+NCCL_IGNORE_CPU_AFFINITY=1 \
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \
+NCCL_PXN_DISABLE=1 \
+NCCL_MIN_NCHANNELS=32 \
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
+SHARP_COLL_ENABLE_SAT=1 \
+SHARP_COLL_LOG_LEVEL=3 \
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+NCCL_COLLNET_ENABLE=1 \
+NCCL_NVLS_ENABLE=0
+
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+
+# Training Data Configuration
+DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1"
+train_file_list=()
+id_val_file_list=()
+
+iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet"
+
+# List of datasets to include (filename only)
+# Comment out lines to exclude specific datasets
+dataset_names=(
+ "math__combined_118.2k.part1.parquet"
+ "math__combined_118.2k.part2.parquet"
+ "omni_math_4.43k_dedup.parquet"
+)
+ # "stem__nemotron_13.3k.parquet"
+ # "stem__web_31.7k.parquet"
+ # "table__hitab_7.4k.parquet"
+ # "table__multihier_2.9k.parquet"
+# "codegen__deduped_leetcode2k_2.4k.parquet"
+# "codegen__deduped_livecodebench_599.parquet"
+# "codegen__deduped_primeintellect_9.6k.parquet"
+# "codegen__deduped_taco_11.1k.parquet"
+# "ifbench__fixed_85.6k.parquet"
+# "simulation__codeio_fixed_12.1k.parquet"
+# "logic__arcagi1_297.parquet"
+# "logic__arcagi2_653.parquet"
+# "logic__barc_3.4k.parquet"
+# "logic__graph_logical_dataset_1.4k.parquet"
+# "logic__ordering_puzzle_dataset_2.9k.parquet"
+# "logic__reasoning_gym_40.6k.parquet"
+# "logic__synlogic_12.1k.parquet"
+# "logic__zebra_puzzle_dataset_5.0k.parquet"
+
+echo "Collecting training files from ${DATA_MIX_DIR}..."
+
+# Search for each dataset in all subdirectories "impossible_questions" "131k_context_questions" "main_questions" "easy_questions"
+for dataset in "${dataset_names[@]}"; do
+ for subdir in "main_questions"; do
+ file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+ if [ -f "$file_path" ]; then
+ echo "Adding: $file_path"
+ train_file_list+=("'$file_path'")
+ fi
+ done
+done
+
+# for dataset in "${dataset_names[@]}"; do
+# for subdir in "131k_context_questions"; do
+# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+# if [ -f "$file_path" ]; then
+# echo "Adding: $file_path"
+# id_val_file_list+=("'$file_path'")
+# fi
+# done
+# done
+# id_val_file_list+=("'$iq400_path'")
+
+# Join with comma to form Python list string
+IFS=,
+train_files="[${train_file_list[*]}]"
+# id_val_files="[${id_val_file_list[*]}]"
+unset IFS
+
+echo "Total training files found: ${#train_file_list[@]}"
+# echo "Total ID validation files found: ${#id_val_file_list[@]}"
+
+# Test Data Configuration
+TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (test)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet
+synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet
+# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet
+
+# Stem (test)
+nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet
+
+# Instruction follow (test)
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet
+if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet
+
+# Focused data mixture (math, code, stem)
+# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Full data mixture (uncomment to use)
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}', '${synlogic_test_path}',
+
+
+# =================== Model ===================
+# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT
+BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k)
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k
+# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/}
+# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491"
+
+# If RESUME_CKPT_DIR is not empty, resume from the checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 32))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=256 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Ulysses sequence-parallel size (actor & ref); comment history: reduced from 32 to reduce memory pressure
+gen_tp=4 # rollout tensor_model_parallel_size
+gen_max_num_seqs=1024 # vLLM rollout.max_num_seqs; NOTE(review): old comment said "reduced from 1024" but value IS 1024 — confirm intended value
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+    actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=${offload} \
+ actor_rollout_ref.model.use_liger=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ +reward_model.reward_kwargs.num_processes=64 \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=False \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=5 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=5 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=3
+ # trainer.log_val_generations=50
\ No newline at end of file
diff --git a/scripts/train/k2p_hero_grpo_newNewData.sh b/scripts/train/k2p_hero_grpo_newNewData.sh
new file mode 100644
index 000000000..a487eaa93
--- /dev/null
+++ b/scripts/train/k2p_hero_grpo_newNewData.sh
@@ -0,0 +1,371 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-stage2-k2pRL-easy50k-7domains
+#SBATCH --nodes=64
+#SBATCH --ntasks=64
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=higherprio
+#SBATCH --exclude=azure-uk-hpc-H200-instance-[028-049]
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-004:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-025:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty
+
+# =================== Cluster Environment ===================
+export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/
+export ROCR_VISIBLE_DEVICES=None
+export NCCL_TIMEOUT_SECONDS=4800000
+export OMPI_MCA_coll_hcoll_enable=0 \
+TORCH_NCCL_ENABLE_MONITORING=0 \
+CUDA_DEVICE_ORDER=PCI_BUS_ID \
+NCCL_SOCKET_IFNAME=eth0 \
+UCX_TLS=rc \
+UCX_NET_DEVICES=mlx5_ib0:1 \
+NCCL_DEBUG=WARN \
+NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \
+NCCL_IB_PCI_RELAXED_ORDERING=1 \
+NCCL_IB_QPS_PER_CONNECTION=4 \
+NCCL_IGNORE_CPU_AFFINITY=1 \
+NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \
+NCCL_PXN_DISABLE=1 \
+NCCL_MIN_NCHANNELS=32 \
+SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
+SHARP_COLL_ENABLE_SAT=1 \
+SHARP_COLL_LOG_LEVEL=3 \
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \
+NCCL_COLLNET_ENABLE=1
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+
+# Training Data Configuration
+# DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1"
+DATA_MIX_DIR="/lustrefs/users/haonan.li/Reasoning360/final/data-mixtures/easy/data_mix_seven_domains_20000"
+train_file_list=()
+id_val_file_list=()
+
+iq400_path="/lustrefs/users/taylor.killian/Reasoning360/data/guru_data/iq400_proxy.parquet"
+
+# List of datasets to include (filename only)
+# Comment out lines to exclude specific datasets
+dataset_names=(
+ "codegen__deduped_leetcode2k_2.4k_scored.parquet"
+ "codegen__deduped_livecodebench_599_scored.parquet"
+ "codegen__deduped_primeintellect_9.6k_scored.parquet"
+ "codegen__deduped_taco_11.1k_scored.parquet"
+ "ifbench__fixed_85.6k_scored.parquet"
+ "logic__arcagi1_297_scored.parquet"
+ "logic__arcagi2_653_scored.parquet"
+ "logic__barc_3.4k_scored.parquet"
+ "logic__graph_logical_dataset_1.4k_scored.parquet"
+ "logic__ordering_puzzle_dataset_2.9k_scored.parquet"
+ "logic__reasoning_gym_40.6k_scored.parquet"
+ "logic__synlogic_12.1k_scored.parquet"
+ "logic__zebra_puzzle_dataset_5.0k_scored.parquet"
+ "math__combined_118.2k.part1_scored.parquet"
+ "math__combined_118.2k.part2_scored.parquet"
+ "omni_math_4.43k_scored.parquet"
+ "simulation__codeio_fixed_12.1k_scored.parquet"
+ "stem__nemotron_13.3k_scored.parquet"
+ "stem__web_31.7k_scored.parquet"
+ "table__hitab_7.4k_scored.parquet"
+ "table__multihier_2.9k_scored.parquet"
+)
+# "omni_math_4.43k_dedup.parquet"
+
+echo "Collecting training files from ${DATA_MIX_DIR}/..."
+
+for dataset in "${dataset_names[@]}"; do
+ for subdir in "easy_questions" "main_questions" "131k_context_questions"; do
+ file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+
+ if [ -f "$file_path" ]; then
+ echo "Found: $file_path"
+ train_file_list+=("'$file_path'")
+
+ # --- NEW ELIF BLOCK ---
+ # Checks for the file with an extra "_scored" inserted before .parquet
+ # e.g., changes "file_scored.parquet" to "file_scored_scored.parquet"
+ elif [ -f "${DATA_MIX_DIR}/${subdir}/${dataset/.parquet/_scored.parquet}" ]; then
+ alt_file_path="${DATA_MIX_DIR}/${subdir}/${dataset/.parquet/_scored.parquet}"
+ echo "Found (with extra _scored): $alt_file_path"
+ train_file_list+=("'$alt_file_path'")
+ else
+ echo "Missing: $file_path"
+ fi
+ done
+done
+
+# for dataset in "${dataset_names[@]}"; do
+# for subdir in "131k_context_questions"; do
+# file_path="${DATA_MIX_DIR}/${subdir}/${dataset}"
+# if [ -f "$file_path" ]; then
+# echo "Adding: $file_path"
+# id_val_file_list+=("'$file_path'")
+# fi
+# done
+# done
+# id_val_file_list+=("'$iq400_path'")
+
+# Join with comma to form Python list string
+IFS=,
+train_files="[${train_file_list[*]}]"
+# id_val_files="[${id_val_file_list[*]}]"
+unset IFS
+
+echo "Total training files found: ${#train_file_list[@]}"
+# echo "Total ID validation files found: ${#id_val_file_list[@]}"
+
+# Test Data Configuration
+TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len
+# Math (test)
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# Code (test)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# Logic (test)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet
+reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet
+synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet
+# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+
+# Table (test)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet
+
+# Stem (test)
+nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet
+
+# Instruction follow (test)
+if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet
+if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet
+
+# Focused data mixture (math, code, stem)
+# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Full data mixture (uncomment to use)
+test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${synlogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # ,'${iq400_path}'
+
+
+# =================== Model ===================
+# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k)
+# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k
+BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406955/global_step_330/actor/huggingface
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/}
+# WANDB_EXPERIMENT_NAME="grpo-k2p-newFiltered-32k-mainQs-finalInstruct-406491"
+
+# If RESUME_CKPT_DIR is not empty, resume from the checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 64))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 12))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=256 # model grad update batchsize
+
+# Algorithm
+temperature=1.2
+val_temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=16 # Ulysses sequence-parallel size (actor & ref); comment history: reduced from 32 to reduce memory pressure
+gen_tp=4 # rollout tensor_model_parallel_size
+gen_max_num_seqs=1024 # vLLM rollout.max_num_seqs; NOTE(review): old comment said "reduced from 1024" but value IS 1024 — confirm intended value
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+ actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=True \
+ actor_rollout_ref.model.use_liger=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=10 \
+ trainer.test_freq=5 \
+ trainer.total_epochs=5 \
+ trainer.log_val_generations=50 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=3
+ # data.id_val_files="$id_val_files" \
\ No newline at end of file
diff --git a/scripts/train/test_k2p_cispo_m2.sh b/scripts/train/test_k2p_cispo_m2.sh
new file mode 100644
index 000000000..01cf6ae81
--- /dev/null
+++ b/scripts/train/test_k2p_cispo_m2.sh
@@ -0,0 +1,304 @@
+#!/bin/bash
+#SBATCH --job-name=cispo-focused-fixed
+#SBATCH --nodes=4
+#SBATCH --ntasks=4
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --account=iq
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=main
+
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="cispo-focused-fixed-Qwen2.5-7B-1016971" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+
+# =================== Cluster Environment ===================
+export CONDA_BIN_PATH=/mnt/weka/home/taylor.killian/miniconda3/envs/sync-rl/bin/
+export NCCL_TIMEOUT_SECONDS=4800
+export TORCH_NCCL_ENABLE_MONITORING=0
+export NCCL_DEBUG=warn
+export NCCL_NET=IB
+export NCCL_IB_HCA="mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7"
+export NCCL_CROSS_NIC=1
+export NCCL_IB_TC=136
+export NCCL_SOCKET_IFNAME="^lo,docker,virbr"
+export CUDA_DEVICE_MAX_CONNECTIONS=8
+export NCCL_NVLS_ENABLE=1
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+SHARED_DATA_PATH=/mnt/sharefs/users/zhuojun.cheng
+SHARED_MODEL_PATH=/mnt/sharefs/users/haonan.li/models
+TRAIN_DATA_DIR=${SHARED_DATA_PATH}/guru_data/train/guru92k_release_0603
+TEST_DATA_DIR=${SHARED_DATA_PATH}/guru_data/test/online # ← unchanged
+
+# ---------- Math ----------
+# train
+math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet
+# test
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# ---------- Code ----------
+# train
+leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_1.3k.parquet
+livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_440.parquet
+primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_7.5k.parquet
+taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_8.8k.parquet
+# test (unchanged)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500_sampled_200.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# ---------- Logic ----------
+# train
+arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet
+arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet
+barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet
+graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.2k.parquet
+ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_1.9k.parquet
+zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_1.3k.parquet
+# test (unchanged)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300_sampled_200.parquet
+graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet
+
+# ---------- Simulation ----------
+# train
+codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k_3.7k.parquet
+# test (unchanged)
+codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet
+
+# ---------- Table ----------
+# train
+hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet
+multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet
+# test (unchanged)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_300_sampled_200.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_300_sampled_200.parquet
+
+# ---------- Stem ----------
+# train
+webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet
+# test (unchanged)
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet
+
+# Full Guru92k mixture
+# train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${graph_test_path}','${ordering_puzzle_test_path}','${arcagi1_test_path}','${codeio_test_path}','${multihier_test_path}','${hitab_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Focused Guru92k mixture (Math + Code + STEM)
+train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}']"
+test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# =================== Model ===================
+BASE_MODEL=Qwen/Qwen2.5-7B
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${BASE_MODEL##*/}-${SLURM_JOB_ID}
+
+# If RESUME_CKPT_DIR is not empty, resume from the checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note: we borrow the DAPO config format but disable its dynamic-sampling/filter-groups features; this run uses the CISPO policy loss rather than the naive RL baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=2.0
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 32))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+loss_mode="cispo" # Default is "vanilla" which is equivalent to PPO;
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=32 # model grad update batchsize
+
+# Algorithm (sampling hyperparameters for rollout generation)
+temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=1
+gen_tp=2
+gen_max_num_seqs=1024 # max concurrent sequences in vLLM generation; lower this to reduce memory pressure
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
+ actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.dtype=${rollout_dtype} \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\
+ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ actor_rollout_ref.rollout.dtype=${rollout_dtype} \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=10 \
+ trainer.test_freq=10 \
+ trainer.total_epochs=5 \
+ trainer.log_val_generations=50 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=1
\ No newline at end of file
diff --git a/scripts/train/test_k2p_grpo_m2.sh b/scripts/train/test_k2p_grpo_m2.sh
new file mode 100644
index 000000000..aefbf34fc
--- /dev/null
+++ b/scripts/train/test_k2p_grpo_m2.sh
@@ -0,0 +1,304 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-focused
+#SBATCH --nodes=4
+#SBATCH --ntasks=4
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:8
+#SBATCH --cpus-per-task=96
+#SBATCH --account=iq
+#SBATCH --mem=0
+#SBATCH --output=slurm/%x-%j.log
+#SBATCH --error=slurm/%x-%j.log
+#SBATCH --exclusive
+#SBATCH --time=720:00:00
+#SBATCH --partition=main
+
+
+# =================== Frequently Used Variables ===================
+RESUME_CKPT_DIR_NAME="grpo-focused-Qwen2.5-7B-994853" # Fill in the checkpoint directory name to resume from, otherwise from scratch
+export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain
+
+# =================== Cluster Environment ===================
+export CONDA_BIN_PATH=/mnt/weka/home/taylor.killian/miniconda3/envs/sync-rl/bin/
+export NCCL_TIMEOUT_SECONDS=4800
+export TORCH_NCCL_ENABLE_MONITORING=0
+export NCCL_DEBUG=warn
+export NCCL_NET=IB
+export NCCL_IB_HCA="mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7"
+export NCCL_CROSS_NIC=1
+export NCCL_IB_TC=136
+export NCCL_SOCKET_IFNAME="^lo,docker,virbr"
+export CUDA_DEVICE_MAX_CONNECTIONS=8
+export NCCL_NVLS_ENABLE=1
+
+# Get the list of allocated nodes
+nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
+echo "Nodes to check: ${nodes[@]}"
+
+# We'll track PIDs so we can wait on them and detect errors
+declare -A pids
+export head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+port=6379
+address_head=$head_node_ip:$port
+
+export worker_num=$SLURM_NNODES
+export HYDRA_FULL_ERROR=1
+export VLLM_USE_V1=1
+
+# =================== Data Mixture ===================
+SHARED_DATA_PATH=/mnt/sharefs/users/zhuojun.cheng
+SHARED_MODEL_PATH=/mnt/sharefs/users/haonan.li/models
+TRAIN_DATA_DIR=${SHARED_DATA_PATH}/guru_data/train/guru92k_release_0603
+TEST_DATA_DIR=${SHARED_DATA_PATH}/guru_data/test/online # ← unchanged
+
+# ---------- Math ----------
+# train
+math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet
+# test
+math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
+aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
+amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet
+
+# ---------- Code ----------
+# train
+leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_1.3k.parquet
+livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_440.parquet
+primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_7.5k.parquet
+taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_8.8k.parquet
+# test (unchanged)
+humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
+mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500_sampled_200.parquet
+livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet
+
+# ---------- Logic ----------
+# train
+arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet
+arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet
+barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet
+graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.2k.parquet
+ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_1.9k.parquet
+zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_1.3k.parquet
+# test (unchanged)
+zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300_sampled_200.parquet
+graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet
+ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet
+arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet
+
+# ---------- Simulation ----------
+# train
+codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k_3.7k.parquet
+# test (unchanged)
+codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet
+
+# ---------- Table ----------
+# train
+hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet
+multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet
+# test (unchanged)
+multihier_test_path=${TEST_DATA_DIR}/table__multihier_300_sampled_200.parquet
+hitab_test_path=${TEST_DATA_DIR}/table__hitab_300_sampled_200.parquet
+
+# ---------- Stem ----------
+# train
+webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet
+# test (unchanged)
+gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
+supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet
+
+# Full Guru92k mixture
+# train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}']"
+# test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${graph_test_path}','${ordering_puzzle_test_path}','${arcagi1_test_path}','${codeio_test_path}','${multihier_test_path}','${hitab_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# Focused Guru92k mixture (Math + Code + STEM)
+train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}']"
+test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']"
+
+# =================== Model ===================
+BASE_MODEL=Qwen/Qwen2.5-7B
+
+# =================== Logging ===================
+WANDB_PROJECT=k2plus_rl
+WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${BASE_MODEL##*/}-${SLURM_JOB_ID}
+
+# If RESUME_CKPT_DIR is not empty, resume from the checkpoint
+if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
+ WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
+fi
+
+
+# =================== Ray start ===================
+# ray stop at all nodes
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop
+
+sleep 10
+# Remove existing Ray cluster
+srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster
+
+# Start Ray head node
+srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &
+
+sleep 10
+
+# Start Ray worker nodes
+for ((i = 1; i < worker_num; i++)); do
+ node_i=${nodes[$i]}
+ echo "Starting WORKER $i at $node_i"
+ srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
+ env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \
+ ${CONDA_BIN_PATH}ray start --address "$address_head" \
+ --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
+done
+sleep 10
+
+
+# =================== RL Config ===================
+# Note: we borrow the DAPO config format but disable all DAPO-specific features to run the naive RL (GRPO) baseline.
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.2
+
+max_prompt_length=$((1024 * 4))
+max_response_length=$((1024 * 32))
+enable_overlong_buffer=False
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+loss_mode="vanilla" # Default is "vanilla" which is equivalent to PPO;
+loss_agg_mode="token-mean"
+rollout_dtype="float16"
+
+enable_filter_groups=False
+filter_groups_metric=acc
+max_num_gen_batches=10
+train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n
+gen_prompt_bsz=$((train_prompt_bsz * 1))
+n_resp_per_prompt=16
+train_prompt_mini_bsz=32 # model grad update batchsize
+
+# Algorithm (sampling hyperparameters for rollout generation)
+temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+# Training config
+sp_size=1
+gen_tp=2
+gen_max_num_seqs=1024 # max concurrent sequences in vLLM generation; lower this to reduce memory pressure
+infer_micro_batch_size=null
+train_micro_batch_size=null
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow
+infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but note memory overflow
+offload=True
+
+# =================== Start RL training ===================
+"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \
+ --config-path=config \
+ --config-name="dapo_fsdp_config.yaml" \
+ algorithm.adv_estimator=${adv_estimator} \
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
+ algorithm.filter_groups.enable=${enable_filter_groups} \
+ algorithm.filter_groups.metric=${filter_groups_metric} \
+ algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.prompt_key=prompt \
+ data.truncation='right' \
+ data.max_prompt_length=${max_prompt_length} \
+ data.max_response_length=${max_response_length} \
+ data.train_batch_size=${train_prompt_bsz} \
+ data.gen_batch_size=${gen_prompt_bsz} \
+ actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \
+ actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
+ actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+ actor_rollout_ref.actor.strategy="fsdp2" \
+ actor_rollout_ref.actor.dtype=${rollout_dtype} \
+ actor_rollout_ref.actor.optim.lr=5e-7 \
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
+ actor_rollout_ref.actor.optim.warmup_style=constant \
+ actor_rollout_ref.actor.optim.min_lr_ratio=0. \
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.grad_clip=1.0 \
+ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+ actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+ actor_rollout_ref.actor.entropy_checkpointing=True \
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+ actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
+ actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
+ actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \
+ actor_rollout_ref.rollout.disable_log_stats=False \
+ actor_rollout_ref.rollout.enforce_eager=False \
+ actor_rollout_ref.rollout.enable_prefix_caching=True \
+ actor_rollout_ref.rollout.temperature=${temperature} \
+ actor_rollout_ref.rollout.top_p=${top_p} \
+ actor_rollout_ref.rollout.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+ actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\
+ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+ actor_rollout_ref.model.path=$BASE_MODEL \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.rollout.multi_turn.enable=False \
+ actor_rollout_ref.rollout.mode="sync" \
+ actor_rollout_ref.rollout.dtype=${rollout_dtype} \
+ +actor_rollout_ref.model.override_config.attention_dropout=0. \
+ +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+ +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.model.enable_activation_offload=True \
+ reward_model.reward_manager=async_multi_process \
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+ trainer.logger=['console','wandb'] \
+ trainer.project_name=${WANDB_PROJECT} \
+ trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
+ trainer.val_before_train=True \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=$worker_num \
+ trainer.save_freq=10 \
+ trainer.test_freq=10 \
+ trainer.total_epochs=5 \
+ trainer.log_val_generations=50 \
+ trainer.resume_mode=auto \
+ trainer.max_actor_ckpt_to_keep=1
\ No newline at end of file
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 444595c76..fe94e7a49 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -56,6 +56,8 @@ actor_rollout_ref:
clip_cov_ub: 5.0
kl_cov_ratio: 0.0002
ppo_kl_coef: 0.1
+ cispo_clip_ratio_high: 0.2
+ cispo_clip_ratio_low: 0.2
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml
index 7c733ed60..e8c082e63 100644
--- a/verl/trainer/config/actor/actor.yaml
+++ b/verl/trainer/config/actor/actor.yaml
@@ -49,7 +49,7 @@ policy_loss:
# # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
_target_: verl.workers.config.PolicyLossConfig
- # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617
+ # Loss function mode: vanilla / clip-cov / kl-cov / cispo (from https://arxiv.org/abs/2506.13585) / gpg (from https://arxiv.org/abs/2505.22617)
loss_mode: "vanilla"
# Ratio of tokens to be clipped for clip-cov loss
@@ -67,6 +67,12 @@ policy_loss:
# KL divergence penalty coefficient
ppo_kl_coef: 0.1
+ # Upper bound for CISPO importance ratio clipping
+ cispo_clip_ratio_high: 0.2
+
+ # Lower bound for CISPO importance ratio clipping
+ cispo_clip_ratio_low: 0.2
+
# Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C
clip_ratio_c: 3.0
diff --git a/verl/trainer/config/data/legacy_data.yaml b/verl/trainer/config/data/legacy_data.yaml
index 028405b42..dde4c197c 100644
--- a/verl/trainer/config/data/legacy_data.yaml
+++ b/verl/trainer/config/data/legacy_data.yaml
@@ -10,6 +10,9 @@ use_shm: False
# For HDFS path, we provide utils to download it to DRAM and convert it to a local path.
train_files: ~/data/rlhf/gsm8k/train.parquet
+# ID Validation set parquet. Can be a list or a single file.
+# id_val_files: null
+
# Validation parquet. Can be a list or a single file.
val_files: ~/data/rlhf/gsm8k/test.parquet
diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 7a9103c4d..425f095b0 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -1164,6 +1164,77 @@ def compute_policy_loss_kl_cov(
return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
+@register_policy_loss("cispo")
+def compute_policy_loss_cispo(
+ old_log_prob: torch.Tensor,
+ log_prob: torch.Tensor,
+ advantages: torch.Tensor,
+ response_mask: torch.Tensor,
+ loss_agg_mode: str = "token-mean",
+ config: Optional[DictConfig | AlgoConfig] = None,
+ rollout_log_probs: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Compute the CISPO policy objective and related metrics.
+ CISPO (Clipped Importance Sampling Policy Optimization) clips importance sampling weights
+ instead of dropping tokens, which is beneficial for training on sparse but critical tokens
+ and long-context reasoning in RL.
+ Reference: https://arxiv.org/abs/2506.13585
+ Args:
+ old_log_prob (torch.Tensor):
+ Log-probabilities of actions under the old policy, shape (batch_size, response_length).
+ log_prob (torch.Tensor):
+ Log-probabilities of actions under the current policy, shape (batch_size, response_length).
+ advantages (torch.Tensor):
+ Advantage estimates for each action, shape (batch_size, response_length).
+ response_mask (torch.Tensor):
+ Mask indicating which tokens to include in the loss, shape (batch_size, response_length).
+ loss_agg_mode (str, optional):
+ Aggregation mode for loss computation
+ config (AlgoConfig):
+ Algorithm configuration containing CISPO parameters
+ rollout_log_probs: `(torch.Tensor)`:
+ log probabilities of actions under the rollout policy, shape (batch_size, response_length).
+ Returns:
+ tuple: (pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)
+ """
+ # Setup CISPO configuration
+ assert config.policy_loss.loss_mode == "cispo", "CISPO loss mode not set in config"
+ cispo_clip_ratio_high = config.policy_loss.cispo_clip_ratio_high
+ cispo_clip_ratio_low = config.policy_loss.cispo_clip_ratio_low  # NOTE(review): currently unused — no lower IS-weight clip is applied in the clamp below; confirm this is intentional (paper applies only the upper clip)
+ clip_ratio_c = config.get("clip_ratio_c", 3.0)
+
+ # Same code as compute_policy_loss
+ assert clip_ratio_c > 1.0, (
+ "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0,"
+ + f" but get the value: {clip_ratio_c}."
+ )
+
+ negative_approx_kl = log_prob - old_log_prob
+ # Clamp negative_approx_kl for stability
+ negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)
+ ratio = torch.exp(negative_approx_kl)
+ ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
+
+ # CISPO specific loss
+ ratio = ratio.detach() # Stop gradient on IS ratio
+ importance_sampling_weight = torch.clamp(ratio, max=1+cispo_clip_ratio_high)
+ pg_losses = -advantages * log_prob * importance_sampling_weight
+
+ if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None:
+ # Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
+ tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
+ tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
+ pg_losses = pg_losses * tis_imp_ratio
+
+ pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+ # For compatibility, return zero for pg_clipfrac_lower and pg_clipfrac (not used in CISPO)
+ pg_clipfrac = torch.tensor(0.0, device=pg_loss.device)
+ pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)
+
+ return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
+
+
@register_policy_loss("geo_mean")
def compute_policy_loss_geo_mean(
old_log_prob: torch.Tensor,
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index e8d3e6a2c..ae31a481a 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -280,6 +280,7 @@ def __init__(
reward_fn=None,
val_reward_fn=None,
train_dataset: Optional[Dataset] = None,
+ # id_val_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
collate_fn=None,
train_sampler: Optional[Sampler] = None,
@@ -299,6 +300,7 @@ def __init__(
reward_fn: Function for computing rewards during training.
val_reward_fn: Function for computing rewards during validation.
train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.
+ # id_val_dataset (Optional[Dataset], optional): id Validation dataset. Defaults to None.
val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.
collate_fn: Function to collate data samples into batches.
train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.
@@ -338,11 +340,12 @@ def __init__(
if self.config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
- self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
+ self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) # id_val_dataset
- def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
+ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): # id_val_dataset
"""
Creates the train and validation dataloaders.
+ # Added by Reasoning 360: creates a third dataloader for id validation... To be kept separate from standard validation approach.
"""
# TODO: we have to make sure the batch size is divisible by the dp size
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
@@ -351,11 +354,16 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
train_dataset = create_rl_dataset(
self.config.data.train_files, self.config.data, self.tokenizer, self.processor
)
+ # if id_val_dataset is None and self.config.data.id_val_files is not None:
+ # id_val_dataset = create_rl_dataset(
+ # self.config.data.id_val_files, self.config.data, self.tokenizer, self.processor
+ # )
if val_dataset is None:
val_dataset = create_rl_dataset(
self.config.data.val_files, self.config.data, self.tokenizer, self.processor
)
- self.train_dataset, self.val_dataset = train_dataset, val_dataset
+ self.train_dataset, self.val_dataset = train_dataset, val_dataset
+ # self.id_val_dataset = id_val_dataset
if train_sampler is None:
train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
@@ -376,9 +384,23 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
)
val_batch_size = self.config.data.val_batch_size # Prefer config value if set
+ # id_val_batch_size = self.config.data.val_batch_size
if val_batch_size is None:
val_batch_size = len(self.val_dataset)
-
+ # id_val_batch_size = len(self.id_val_dataset) if self.id_val_dataset is not None else 0
+
+ # if self.id_val_dataset is not None:
+ # self.id_val_dataloader = StatefulDataLoader(
+ # dataset=self.id_val_dataset,
+ # batch_size=id_val_batch_size,
+ # num_workers=num_workers,
+ # shuffle=self.config.data.get("id_validation_shuffle", True),
+ # drop_last=False,
+ # collate_fn=collate_fn,
+ # )
+ # else:
+ # self.id_val_dataloader = None
+
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=val_batch_size,
@@ -389,12 +411,13 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
)
assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
+ # assert self.id_val_dataloader is None or len(self.id_val_dataloader) >= 1, "id Validation dataloader is empty!"
assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"
print(
f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: "
f"{len(self.val_dataloader)}"
- )
+ ) # Size of id val dataloader: {len(self.id_val_dataloader)}
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
@@ -485,9 +508,12 @@ def _get_gen_batch(self, batch: DataProto) -> DataProto:
def _validate(self):
data_source_lst = []
+ # id_data_source_lst = []
reward_extra_infos_dict: dict[str, list] = defaultdict(list)
+ # id_reward_extra_infos_dict: dict[str, list] = defaultdict(list)
# NOTE: added by Reasoning360.
dataset_lst = []
+ # dataset_id_lst = []
# Lists to collect samples for the table
sample_inputs = []
@@ -497,6 +523,112 @@ def _validate(self):
sample_turns = []
sample_uids = []
+ # sample_id_inputs = []
+ # sample_id_outputs = []
+ # sample_id_gts = []
+ # sample_id_scores = []
+ # sample_id_turns = []
+ # sample_id_uids = []
+
+ # if self.id_val_dataloader is not None:
+ # print("Starting id validation generation...")
+ # for test_data in self.id_val_dataloader:
+ # test_batch = DataProto.from_single_dict(test_data)
+
+ # if "uid" not in test_batch.non_tensor_batch:
+ # test_batch.non_tensor_batch["uid"] = np.array(
+ # [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
+ # )
+
+ # # repeat test batch
+ # test_batch = test_batch.repeat(
+ # repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
+ # )
+
+ # # we only do validation on rule-based rm
+ # if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
+ # return {}
+
+ # # Store original inputs
+ # input_ids = test_batch.batch["input_ids"]
+ # # TODO: Can we keep special tokens except for padding tokens?
+ # input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
+ # sample_id_inputs.extend(input_texts)
+ # sample_id_uids.extend(test_batch.non_tensor_batch["uid"])
+
+ # ground_truths = [
+ # item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
+ # ]
+ # sample_id_gts.extend(ground_truths)
+
+ # test_gen_batch = self._get_gen_batch(test_batch)
+ # test_gen_batch.meta_info = {
+ # "eos_token_id": self.tokenizer.eos_token_id,
+ # "pad_token_id": self.tokenizer.pad_token_id,
+ # "recompute_log_prob": False,
+ # "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
+ # "validate": True,
+ # "global_steps": self.global_steps,
+ # }
+ # print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
+
+ # # pad to be divisible by dp_size
+ # size_divisor = (
+ # self.actor_rollout_wg.world_size
+ # if not self.async_rollout_mode
+ # else self.config.actor_rollout_ref.rollout.agent.num_workers
+ # )
+ # test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
+ # if not self.async_rollout_mode:
+ # test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
+ # else:
+ # test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
+
+ # # unpad
+ # test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
+
+ # print("ID Validation generation end")
+
+ # # Store generated outputs
+ # output_ids = test_output_gen_batch.batch["responses"]
+ # output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
+ # sample_id_outputs.extend(output_texts)
+
+ # test_batch = test_batch.union(test_output_gen_batch)
+ # test_batch.meta_info["validate"] = True
+
+ # # evaluate using reward_function
+ # if self.val_reward_fn is None:
+ # raise ValueError("val_reward_fn must be provided for validation.")
+ # result = self.val_reward_fn(test_batch, return_dict=True)
+ # reward_tensor = result["reward_tensor"]
+ # scores = reward_tensor.sum(-1).cpu().tolist()
+ # sample_id_scores.extend(scores)
+
+ # id_reward_extra_infos_dict["reward"].extend(scores)
+ # print(f"len id_reward_extra_infos_dict['reward']: {len(id_reward_extra_infos_dict['reward'])}")
+ # if "reward_extra_info" in result:
+ # for key, lst in result["reward_extra_info"].items():
+ # id_reward_extra_infos_dict[key].extend(lst)
+ # print(f"len id_reward_extra_infos_dict['{key}']: {len(id_reward_extra_infos_dict[key])}")
+
+ # # NOTE: added by Reasoning360. Collect dataset information. TODO: maybe replicated usage with the data_source_lst and can be removed?
+ # datasets = []
+ # for i in range(reward_tensor.shape[0]):
+ # dataset = "unknown"
+ # if "extra_info" in test_batch.non_tensor_batch:
+ # extra_info = test_batch.non_tensor_batch["extra_info"][i]
+ # if isinstance(extra_info, dict) and "dataset" in extra_info:
+ # dataset = extra_info["dataset"]
+ # datasets.append(dataset)
+ # dataset_id_lst.append(np.array(datasets))
+
+ # # collect num_turns of each prompt
+ # if "__num_turns__" in test_batch.non_tensor_batch:
+ # sample_id_turns.append(test_batch.non_tensor_batch["__num_turns__"])
+
+ # id_data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))
+
for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
@@ -608,13 +740,37 @@ def _validate(self):
dump_path=val_data_dir,
)
+ # for key_info, lst in id_reward_extra_infos_dict.items():
+ # assert len(lst) == 0 or len(lst) == len(sample_id_scores), f"{key_info}: {len(lst)=}, {len(sample_id_scores)=}"
+
for key_info, lst in reward_extra_infos_dict.items():
assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"
# NOTE: Added by Reasoning360: Calculate the mean reward for each data source and dataset
+ # id_data_sources = np.concatenate(id_data_source_lst, axis=0)
+ # id_datasets = np.concatenate(dataset_id_lst, axis=0) # Concatenate datasets
data_sources = np.concatenate(data_source_lst, axis=0)
datasets = np.concatenate(dataset_lst, axis=0) # Concatenate datasets
+ # id_data_src2var2metric2val = process_validation_metrics(id_data_sources, sample_id_uids, id_reward_extra_infos_dict)
+ # id_metric_dict = {}
+ # for data_source, var2metric2val in id_data_src2var2metric2val.items():
+ # core_var = "acc" if "acc" in var2metric2val else "reward"
+ # for var_name, metric2val in var2metric2val.items():
+ # n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
+ # for metric_name, metric_val in metric2val.items():
+ # # NOTE: added by Reasoning360. Add std metrics
+ # if (
+ # (var_name == core_var)
+ # and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best", "std"])
+ # and (f"@{n_max}" in metric_name)
+ # ):
+ # metric_sec = "InDomain-Eval-core"
+ # else:
+ # metric_sec = "InDomain-Eval-aux"
+ # pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
+ # id_metric_dict[pfx] = metric_val
+
data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
metric_dict = {}
for data_source, var2metric2val in data_src2var2metric2val.items():
@@ -654,7 +810,7 @@ def _validate(self):
for (data_source, dataset), rewards in data_source_dataset_reward.items():
metric_dict[f"val/test_score/{data_source}/{dataset}"] = np.mean(rewards)
- return metric_dict
+ return metric_dict # id_metric_dict | metric_dict # Union of two dicts
def init_workers(self):
"""Initialize distributed training workers using Ray backend.
diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py
index 159c25890..b09ec3dda 100644
--- a/verl/utils/py_functional.py
+++ b/verl/utils/py_functional.py
@@ -111,8 +111,19 @@ def wrapper_mp(*args, **kwargs):
if process.is_alive():
process.terminate()
process.join(timeout=0.5) # Give it a moment to terminate
+ if process.is_alive():
+ try:
+ process.kill()
+ print(f"Warning: Escalated to force kill process {process.pid}")
+ except AttributeError:
+ os.kill(process.pid, signal.SIGKILL)
+ print(f"Warning: Fell back to os.kill(SIGKILL) for process {process.pid}")
+ process.join(timeout=0.5)
+
if process.is_alive():
print(f"Warning: Process {process.pid} did not terminate gracefully after timeout.")
+ raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds and could not be killed (pid={process.pid})!")
+
# Update function name in error message if needed (optional but good practice)
raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!")
diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py
index bb014b13d..a1b07cd58 100644
--- a/verl/utils/reward_score/__init__.py
+++ b/verl/utils/reward_score/__init__.py
@@ -13,6 +13,8 @@
# limitations under the License.
# from . import gsm8k, math, prime_math, prime_code
+import time
+
from verl.utils.import_utils import deprecated
@@ -40,6 +42,8 @@ def default_compute_score(
Raises:
NotImplementedError: If the reward function is not implemented for the given data source.
"""
+ start_time = time.perf_counter()
+
# Handle extra_info format robustly
reward_metric = None
if extra_info and isinstance(extra_info, dict):
@@ -209,12 +213,28 @@ def default_compute_score(
else:
raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
+ # Calculate elapsed time
+ elapsed = time.perf_counter() - start_time
+
+ # Log slow samples (>10 seconds)
+ if elapsed > 10.0:
+ print(f"[SLOW REWARD] {data_source}: {elapsed:.2f}s")
+
+ # Return result with timing metadata
+ # Always ensure "score" and "acc" keys are present for consistency
if isinstance(res, dict):
+ res["_reward_time"] = elapsed
+ res["_data_source"] = data_source
+ # Ensure "acc" is present if not already
+ if "acc" not in res:
+ res["acc"] = res.get("score", 0.0)
return res
elif isinstance(res, int | float | bool):
- return float(res)
+ score = float(res)
+ return {"score": score, "acc": score, "_reward_time": elapsed, "_data_source": data_source}
else:
- return float(res[0])
+ score = float(res[0])
+ return {"score": score, "acc": score, "_reward_time": elapsed, "_data_source": data_source}
@deprecated("verl.utils.reward_score.default_compute_score")
diff --git a/verl/utils/reward_score/arcagi.py b/verl/utils/reward_score/arcagi.py
index b8367d890..b30e0d21e 100644
--- a/verl/utils/reward_score/arcagi.py
+++ b/verl/utils/reward_score/arcagi.py
@@ -2,8 +2,9 @@
import ast
import numpy as np
+
def extract_solution(solution_str):
- answer_pattern = r'<answer>(.*?)</answer>'
+ answer_pattern = r"<answer>(.*?)</answer>"
matches = list(re.finditer(answer_pattern, solution_str, flags=re.DOTALL))
if matches:
final_answer = matches[-1].group(1).strip()
@@ -11,8 +12,8 @@ def extract_solution(solution_str):
final_answer = final_answer.replace("...", "-1")
try:
# Find the part of the text that looks like a nested list
- start = final_answer.index('[[')
- end = final_answer.index(']]', start) + 2
+ start = final_answer.index("[[")
+ end = final_answer.index("]]", start) + 2
array_str = final_answer[start:end]
# Use ast.literal_eval to safely evaluate the string as a Python expression
array = ast.literal_eval(array_str)
@@ -25,6 +26,7 @@ def extract_solution(solution_str):
else:
return [[0]]
+
def pad_array_with_value(array, target_shape, pad_value):
"""
Pad the given array to the target shape with the specified pad value.
@@ -53,7 +55,7 @@ def pad_array_with_value(array, target_shape, pad_value):
except Exception as e:
array = np.array([[0]])
original_shape = array.shape
- padded_array[:original_shape[0], :original_shape[1]] = array
+ padded_array[: original_shape[0], : original_shape[1]] = array
return padded_array
@@ -74,19 +76,30 @@ def compare_solutions_with_padding(generated_output, correct_output, pad_value=-
max_rows = max(len(generated_output), len(correct_output))
max_cols = max(len(generated_output[0]), len(correct_output[0]))
target_shape = (max_rows, max_cols)
-
+
padded_generated = pad_array_with_value(generated_output, target_shape, pad_value)
padded_correct = pad_array_with_value(correct_output, target_shape, pad_value)
total_pixels = max_rows * max_cols
- correct_pixels = np.sum((padded_generated == padded_correct) & (padded_generated != pad_value) & (padded_correct != pad_value))
- correct_percentage = (correct_pixels / total_pixels)
+ correct_pixels = np.sum(
+ (padded_generated == padded_correct)
+ & (padded_generated != pad_value)
+ & (padded_correct != pad_value)
+ )
+ correct_percentage = correct_pixels / total_pixels
is_correct = float(correct_pixels == total_pixels)
return is_correct, correct_percentage
-
-def compute_score(model_output: str, ground_truth: np.ndarray, extra_info: any = None) -> float:
- model_output = str(model_output)
- final_answer = extract_solution(model_output)
- is_correct, correct_percentage = compare_solutions_with_padding(final_answer, ground_truth)
- return {"score": is_correct, "acc": is_correct}
\ No newline at end of file
+def compute_score(
+ model_output: str, ground_truth: np.ndarray, extra_info: any = None
+) -> float:
+ try:
+ model_output_str = str(model_output)
+ final_answer = extract_solution(model_output_str)
+ is_correct, correct_percentage = compare_solutions_with_padding(
+ final_answer, ground_truth
+ )
+ return {"score": is_correct, "acc": is_correct}
+ except Exception as e:
+ print(f"Error in compute_score in arcagi: {e}")
+ return {"score": 0.0, "acc": 0.0}
diff --git a/verl/utils/reward_score/codeio.py b/verl/utils/reward_score/codeio.py
index 44955ae42..5a84c719f 100644
--- a/verl/utils/reward_score/codeio.py
+++ b/verl/utils/reward_score/codeio.py
@@ -139,8 +139,13 @@ def compute_score(model_output: str, ground_truth: str, extra_info: any = None)
"""
Compute score dict for evaluation harness.
"""
- correct, _ = check_accuracy(str(model_output), str(ground_truth), any_order=False)
- return {"score": correct, "acc": correct}
+
+ try:
+ correct, _ = check_accuracy(str(model_output), str(ground_truth), any_order=False)
+ return {"score": correct, "acc": correct}
+ except Exception as e:
+ print(f"Error in compute_score in codeio: {e}")
+ return {"score": False, "acc": False}
# --------------------------- test --------------------------- #
diff --git a/verl/utils/reward_score/cruxeval/utils.py b/verl/utils/reward_score/cruxeval/utils.py
index 9eec77cd6..f554152c6 100644
--- a/verl/utils/reward_score/cruxeval/utils.py
+++ b/verl/utils/reward_score/cruxeval/utils.py
@@ -9,9 +9,10 @@
import multiprocessing
import os
import platform
-import signal
import tempfile
+from verl.utils.py_functional import timeout_limit
+
def check_correctness(check_program, timeout=3):
"""
@@ -53,12 +54,15 @@ def unsafe_execute(check_program, result, timeout):
# Run program.
try:
- exec_globals = {}
- with swallow_io():
- with time_limit(timeout):
+ @timeout_limit(seconds=timeout)
+ def _exec_with_timeout():
+ exec_globals = {}
+ with swallow_io():
exec(check_program, exec_globals)
+
+ _exec_with_timeout()
result.append("passed")
- except TimeoutException:
+ except TimeoutError:
result.append("timed out")
except BaseException as e:
result.append(f"failed: {e}")
@@ -69,19 +73,6 @@ def unsafe_execute(check_program, result, timeout):
os.chdir = chdir
-@contextlib.contextmanager
-def time_limit(seconds):
- def signal_handler(signum, frame):
- raise TimeoutException("Timed out!")
-
- signal.setitimer(signal.ITIMER_REAL, seconds)
- signal.signal(signal.SIGALRM, signal_handler)
- try:
- yield
- finally:
- signal.setitimer(signal.ITIMER_REAL, 0)
-
-
@contextlib.contextmanager
def swallow_io():
stream = WriteOnlyStringIO()
@@ -98,10 +89,6 @@ def create_tempdir():
yield dirname
-class TimeoutException(Exception):
- pass
-
-
class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from"""
diff --git a/verl/utils/reward_score/graph_dataset.py b/verl/utils/reward_score/graph_dataset.py
index 3bff2554a..51983e330 100644
--- a/verl/utils/reward_score/graph_dataset.py
+++ b/verl/utils/reward_score/graph_dataset.py
@@ -2,70 +2,57 @@
import random
import ast
import operator
-import signal
-import contextlib
-class TimeoutException(Exception):
- pass
-
-@contextlib.contextmanager
-def time_limit(seconds: float):
- def signal_handler(signum, frame):
- raise TimeoutException("Timed out!")
-
- signal.setitimer(signal.ITIMER_REAL, seconds)
- signal.signal(signal.SIGALRM, signal_handler)
- try:
- yield
- finally:
- signal.setitimer(signal.ITIMER_REAL, 0)
def extract_solution(solution_str):
- answer_pattern = r'<answer>(.*?)</answer>'
+ answer_pattern = r"<answer>(.*?)</answer>"
match = re.finditer(answer_pattern, solution_str)
matches = list(match)
if matches:
final_answer = matches[-1].group(1).strip()
- if re.search(r'^[A-Za-z]+$', final_answer):
+ if re.search(r"^[A-Za-z]+$", final_answer):
return final_answer
else:
return None
-def compute_score(solution_str, ground_truth, extra_info: any = None, timeout: float = 10.0):
+
+def compute_score(
+ solution_str, ground_truth, extra_info: any = None, timeout: float = 10.0
+):
"""The scoring function for graph dataset task.
-
+
Args:
solution_str: the solution text
ground_truth: the correct answer
timeout: maximum time in seconds to allow for computation
"""
- try:
- with time_limit(timeout):
- if not isinstance(ground_truth, str):
- ground_truth = str(ground_truth)
- target = ground_truth.lower()
- solution = extract_solution(solution_str)
-
- if solution:
- solution = solution.lower()
- else:
- score = 0.0
+ def _compute_with_timeout():  # NOTE(review): despite the name, no time limit is enforced here — the `timeout` argument is now unused; confirm this is acceptable
+ if not isinstance(ground_truth, str):
+ ground_truth_str = str(ground_truth)
+ else:
+ ground_truth_str = ground_truth
- try:
- if target == solution:
- score = 1.0
- else:
- score = 0.0
+ target = ground_truth_str.lower()
+ solution = extract_solution(solution_str)
- except Exception as e:
- score = 0.0
+ if solution:
+ solution = solution.lower()
+ else:
+ return 0.0
- except TimeoutException:
- print("Computation timed out in graph_dataset")
- score = 0.0
+ try:
+ if target == solution:
+ return 1.0
+ else:
+ return 0.0
+ except Exception as e:
+ return 0.0
+
+ try:
+ score = _compute_with_timeout()
except Exception as e:
print(f"Error in compute_score in graph_dataset: {e}")
score = 0.0
diff --git a/verl/utils/reward_score/ifbench/__init__.py b/verl/utils/reward_score/ifbench/__init__.py
index 5d53be4dc..12f07b962 100644
--- a/verl/utils/reward_score/ifbench/__init__.py
+++ b/verl/utils/reward_score/ifbench/__init__.py
@@ -18,50 +18,55 @@ def compute_score(solution_str, ground_truth, extra_info=None):
Returns:
dict: {"score": float, "acc": bool}
"""
- # Strip off any thinking section
- if "</think>" in solution_str:
- answer = solution_str.split("</think>", 1)[1].strip()
- else:
- answer = solution_str.strip()
- # Parse ground_truth if it's a string
- if isinstance(ground_truth, str):
- try:
- gt_list = ast.literal_eval(ground_truth)
- except Exception:
- gt_list = json.loads(ground_truth)
- else:
- gt_list = ground_truth
+ try:
+ # Strip off any thinking section
+ if "</think>" in solution_str:
+ answer = solution_str.split("</think>", 1)[1].strip()
+ else:
+ answer = solution_str.strip()
- # Take the first set of constraints
- if not isinstance(gt_list, list) or not gt_list:
- return {"score": 0.0, "acc": False}
- first_item = gt_list[0]
- instruction_ids = first_item.get("instruction_id", [])
- kwargs_list = first_item.get("kwargs", [])
+ # Parse ground_truth if it's a string
+ if isinstance(ground_truth, str):
+ try:
+ gt_list = ast.literal_eval(ground_truth)
+ except Exception:
+ gt_list = json.loads(ground_truth)
+ else:
+ gt_list = ground_truth
+
+ # Take the first set of constraints
+ if not isinstance(gt_list, list) or not gt_list:
+ return {"score": 0.0, "acc": False}
+ first_item = gt_list[0]
+ instruction_ids = first_item.get("instruction_id", [])
+ kwargs_list = first_item.get("kwargs", [])
- # Evaluate each instruction
- results = []
- for instr_id, raw_args in zip(instruction_ids, kwargs_list):
- # Prepare args dict
- args = {} if raw_args is None else raw_args
- # Convert numpy and floats
- clean_args = {}
- for key, val in args.items():
- if isinstance(val, float):
- clean_args[key] = int(val)
- elif isinstance(val, np.ndarray):
- clean_args[key] = val.tolist()
- else:
- clean_args[key] = val
+ # Evaluate each instruction
+ results = []
+ for instr_id, raw_args in zip(instruction_ids, kwargs_list):
+ # Prepare args dict
+ args = {} if raw_args is None else raw_args
+ # Convert numpy and floats
+ clean_args = {}
+ for key, val in args.items():
+ if isinstance(val, float):
+ clean_args[key] = int(val)
+ elif isinstance(val, np.ndarray):
+ clean_args[key] = val.tolist()
+ else:
+ clean_args[key] = val
- # Build and check instruction
- instr_cls = INSTRUCTION_DICT[instr_id]
- instr = instr_cls(instr_id)
- instr.build_description(**clean_args)
- passed = bool(answer and instr.check_following(answer))
- results.append(passed)
+ # Build and check instruction
+ instr_cls = INSTRUCTION_DICT[instr_id]
+ instr = instr_cls(instr_id)
+ instr.build_description(**clean_args)
+ passed = bool(answer and instr.check_following(answer))
+ results.append(passed)
- # Return 1.0 if all constraints are satisfied, 0.0 otherwise
- score = 1.0 if all(results) else 0.0
- return {"score": score, "acc": score == 1.0}
+ # Return 1.0 if all constraints are satisfied, 0.0 otherwise
+ score = 1.0 if all(results) else 0.0
+ return {"score": score, "acc": score == 1.0}
+ except Exception as e:
+ print(f"Error in compute_score in ifbench: {e}")
+ return {"score": 0.0, "acc": False}
diff --git a/verl/utils/reward_score/ifeval/__init__.py b/verl/utils/reward_score/ifeval/__init__.py
index 3f6c312fb..aaf4dceac 100644
--- a/verl/utils/reward_score/ifeval/__init__.py
+++ b/verl/utils/reward_score/ifeval/__init__.py
@@ -1,5 +1,6 @@
from verl.utils.reward_score.ifeval import instructions_registry
import numpy as np
+
def compute_score(solution_str, ground_truth, extra_info):
"""The scoring function for IFEval.
@@ -12,36 +13,40 @@ def compute_score(solution_str, ground_truth, extra_info):
format_score: the score for the format
score: the score for the correct answer
"""
- if "</think>" in solution_str:
- answer = solution_str.split("</think>")[1]
- else:
- answer = solution_str
- is_following_list = []
- for index, instruction_id in enumerate(extra_info["instruction_id_list"]):
- instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
- instruction = instruction_cls(instruction_id)
+ try:
+ if "</think>" in solution_str:
+ answer = solution_str.split("</think>")[1]
+ else:
+ answer = solution_str
+ is_following_list = []
+ for index, instruction_id in enumerate(extra_info["instruction_id_list"]):
+ instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
+ instruction = instruction_cls(instruction_id)
- # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
- # print(ground_truth)
- # print(extra_info["instruction_id_list"])
- # print(len(extra_info["instruction_id_list"]))
- # if v is np.ndarray, turn it into a list
- # if v is float, turn it into an int
- kwargs = {
- k: int(v) if isinstance(v, float) else v.tolist() if isinstance(v, np.ndarray) else v for k, v in ground_truth[index].items() if v is not None
- }
+ # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
+ # print(ground_truth)
+ # print(extra_info["instruction_id_list"])
+ # print(len(extra_info["instruction_id_list"]))
+ # if v is np.ndarray, turn it into a list
+ # if v is float, turn it into an int
+ kwargs = {
+ k: int(v) if isinstance(v, float) else v.tolist() if isinstance(v, np.ndarray) else v for k, v in ground_truth[index].items() if v is not None
+ }
- instruction.build_description(**kwargs)
- args = instruction.get_instruction_args()
- if args and "prompt" in args:
- instruction.build_description(prompt=extra_info["prompt"])
+ instruction.build_description(**kwargs)
+ args = instruction.get_instruction_args()
+ if args and "prompt" in args:
+ instruction.build_description(prompt=extra_info["prompt"])
- if answer.strip() and instruction.check_following(answer):
- is_following_list.append(True)
- else:
- is_following_list.append(False)
+ if answer.strip() and instruction.check_following(answer):
+ is_following_list.append(True)
+ else:
+ is_following_list.append(False)
- return {
- "score": all(is_following_list),
- "acc": all(is_following_list),
- }
+ return {
+ "score": all(is_following_list),
+ "acc": all(is_following_list),
+ }
+ except Exception as e:
+ print(f"Error in compute_score in ifeval: {e}")
+ return {"score": 0.0, "acc": False}
\ No newline at end of file
diff --git a/verl/utils/reward_score/math_llm_judge/__init__.py b/verl/utils/reward_score/math_llm_judge/__init__.py
index 7d2b20cae..f252aae80 100644
--- a/verl/utils/reward_score/math_llm_judge/__init__.py
+++ b/verl/utils/reward_score/math_llm_judge/__init__.py
@@ -36,15 +36,18 @@
"""
import re
+import os
+import math
+
import sympy
-from pylatexenc import latex2text
+# from pylatexenc import latex2text
from sympy.parsing import sympy_parser
-import os
+import requests
+from verl.utils.py_functional import timeout_limit
+
from . import math_normalize
from .grader import math_equal
-import requests
-
# import math_normalize
# from grader import math_equal
@@ -54,33 +57,6 @@
TUPLE_CHARS = "()[]"
-def timeout(timeout_seconds: int = 8):
- if os.name == "posix":
- import signal
-
- def decorator(func):
-
- def handler(signum, frame):
- raise TimeoutError("Operation timed out!")
-
- def wrapper(*args, **kwargs):
- old_handler = signal.getsignal(signal.SIGALRM)
- signal.signal(signal.SIGALRM, handler)
- signal.alarm(timeout_seconds)
-
- try:
- return func(*args, **kwargs)
- finally:
- signal.alarm(0)
- signal.signal(signal.SIGALRM, old_handler)
-
- return wrapper
-
- return decorator
- else:
- raise NotImplementedError(f"Unsupported OS: {os.name}")
-
-
def _sympy_parse(expr: str):
"""Parses an expression with sympy."""
py_expr = expr.replace("^", "**")
@@ -95,7 +71,7 @@ def _parse_latex(expr: str) -> str:
expr = expr.replace("\\tfrac", "\\frac")
expr = expr.replace("\\dfrac", "\\frac")
expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers.
- expr = latex2text.LatexNodes2Text().latex_to_text(expr)
+ # expr = latex2text.LatexNodes2Text().latex_to_text(expr)
# Replace the specific characters that this parser uses.
expr = expr.replace("√", "sqrt")
@@ -255,7 +231,7 @@ def should_allow_eval(expr: str):
return True
-@timeout(timeout_seconds=10)
+@timeout_limit(seconds=10)
def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
are_equal = False
try:
@@ -332,7 +308,10 @@ def grade_answer(given_answer: str, ground_truth: str) -> bool:
# if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify)
is_correct = False
else:
- is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
+ try:
+ is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
+ except TimeoutError:
+ is_correct = False
if not is_correct:
break
@@ -392,20 +371,31 @@ def match_answer(response):
return is_matched, response
-import math
-
def llm_check_answer(model_output: str, ground_truth: str, question: str) -> bool:
# use llm to check if the answer is correct
+ # Supports multiple endpoints for load balancing - separate URLs with commas
+ # e.g., MATH_LLM_JUDGE_URL="http://host1:8000,http://host2:8000,http://host3:8000"
+
+ import os
+ import random
- # url = "http://176.56.200.81:30000/v1/chat/completions"
- url = os.getenv("MATH_LLM_JUDGE_URL")
- if not url:
+ url_base_str = os.getenv("MATH_LLM_JUDGE_URL")
+ if not url_base_str:
raise ValueError("MATH_LLM_JUDGE_URL is not set")
-
+
+ # Support multiple endpoints separated by commas
+ endpoints = [url.strip() for url in url_base_str.split(",") if url.strip()]
+ if not endpoints:
+ raise ValueError("MATH_LLM_JUDGE_URL contains no valid endpoints")
+
+ # Randomly select an endpoint for load balancing
+ url_base = random.choice(endpoints)
+ url = url_base.rstrip("/") + "/v1/chat/completions"
+
prompt = input_template.format(QUESTION=question, STUDENT_ANSWER=model_output, REFERENCE_ANSWER=ground_truth)
-
+
data = {
- "model": "Qwen/Qwen2.5-32B-Instruct",
+ "model": "openai/gpt-oss-120b",
"messages": [{"role": "user", "content": prompt}],
}
response = requests.post(url, json=data)
@@ -423,7 +413,7 @@ def llm_check_answer(model_output: str, ground_truth: str, question: str) -> boo
def compute_score(model_output: str,
ground_truth: str,
extra_info: dict) -> bool:
- question = extra_info["question"]
+ question = extra_info["original_question"]
model_output = str(model_output)
ground_truth = str(ground_truth)
@@ -447,5 +437,4 @@ def compute_score(model_output: str,
if is_matched and not is_correct:
# use llm to check if the answer is correct
is_correct = llm_check_answer(extracted_model_output, ground_truth, question)
-
return is_correct, 1, extracted_model_output
diff --git a/verl/utils/reward_score/math_llm_judge/grader.py b/verl/utils/reward_score/math_llm_judge/grader.py
index 34f0c7f52..5eaa22679 100644
--- a/verl/utils/reward_score/math_llm_judge/grader.py
+++ b/verl/utils/reward_score/math_llm_judge/grader.py
@@ -92,9 +92,7 @@
- https://github.com/openai/prm800k
"""
-import contextlib
import re
-import signal
import math
from math import isclose
from typing import Union
@@ -102,6 +100,7 @@
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
+from verl.utils.py_functional import timeout_limit
def is_digit(s):
@@ -304,16 +303,16 @@ def math_equal(prediction: Union[bool, float, str],
except Exception:
pass
- return symbolic_equal(prediction, reference, tolerance, timeout)
+ return symbolic_equal(prediction, reference, tolerance)
-def symbolic_equal(a, b, tolerance, timeout=10.0):
+@timeout_limit(seconds=10)
+def symbolic_equal(a, b, tolerance):
def _parse(s):
for f in [parse_expr, parse_latex]:
try:
- with time_limit(timeout):
- return f(s)
+ return f(s)
except Exception:
pass
return s
@@ -322,39 +321,19 @@ def _parse(s):
b = _parse(b)
try:
- with time_limit(timeout):
- if simplify(a - b) == 0:
- return True
+ if simplify(a - b) == 0:
+ return True
except Exception:
pass
try:
- with time_limit(timeout):
- if isclose(N(a), N(b), rel_tol=tolerance):
- return True
+ if isclose(N(a), N(b), rel_tol=tolerance):
+ return True
except Exception:
pass
return False
-class TimeoutException(Exception):
- pass
-
-
-@contextlib.contextmanager
-def time_limit(seconds: float):
-
- def signal_handler(signum, frame):
- raise TimeoutException("Timed out!")
-
- signal.setitimer(signal.ITIMER_REAL, seconds)
- signal.signal(signal.SIGALRM, signal_handler)
- try:
- yield
- finally:
- signal.setitimer(signal.ITIMER_REAL, 0)
-
-
def format_intervals(prediction):
patterns = {
"Interval(": r"^Interval\((.*)\)$",
diff --git a/verl/utils/reward_score/naive_dapo.py b/verl/utils/reward_score/naive_dapo.py
index d26a1dd72..af4ba8bee 100644
--- a/verl/utils/reward_score/naive_dapo.py
+++ b/verl/utils/reward_score/naive_dapo.py
@@ -14,34 +14,18 @@
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
import re
-import signal
from typing import Optional
+import math
import sympy
-from pylatexenc import latex2text
+# from pylatexenc import latex2text
from sympy.parsing import sympy_parser
-import os
+from verl.utils.py_functional import timeout_limit
+
from .prime_math import math_normalize
from .prime_math.grader import math_equal
-class timeout:
-
- def __init__(self, seconds=1, error_message="Timeout"):
- self.seconds = seconds
- self.error_message = error_message
-
- def handle_timeout(self, signum, frame):
- raise TimeoutError(self.error_message)
-
- def __enter__(self):
- signal.signal(signal.SIGALRM, self.handle_timeout)
- signal.alarm(self.seconds)
-
- def __exit__(self, type, value, traceback):
- signal.alarm(0)
-
-
# Constants for normalization
SUBSTITUTIONS = [
("an ", ""),
@@ -103,10 +87,10 @@ def __exit__(self, type, value, traceback):
def normalize_final_answer(final_answer: str) -> str:
"""Normalize a final answer to a quantitative reasoning question.
-
+
Args:
final_answer: The answer string to normalize
-
+
Returns:
Normalized answer string
"""
@@ -148,48 +132,25 @@ def normalize_final_answer(final_answer: str) -> str:
TUPLE_CHARS = "()[]"
-def timeout(timeout_seconds: int = 8):
- if os.name == "posix":
- import signal
-
- def decorator(func):
-
- def handler(signum, frame):
- raise TimeoutError("Operation timed out!")
-
- def wrapper(*args, **kwargs):
- old_handler = signal.getsignal(signal.SIGALRM)
- signal.signal(signal.SIGALRM, handler)
- signal.alarm(timeout_seconds)
-
- try:
- return func(*args, **kwargs)
- finally:
- signal.alarm(0)
- signal.signal(signal.SIGALRM, old_handler)
-
- return wrapper
-
- return decorator
- else:
- raise NotImplementedError(f"Unsupported OS: {os.name}")
-
-
def _sympy_parse(expr: str):
"""Parses an expression with sympy."""
py_expr = expr.replace("^", "**")
return sympy_parser.parse_expr(
py_expr,
- transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)),
+ transformations=(
+ sympy_parser.standard_transformations
+ + (sympy_parser.implicit_multiplication_application,)
+ ),
)
+# @timeout(timeout_seconds=5)
def _parse_latex(expr: str) -> str:
"""Attempts to parse latex to an expression sympy can read."""
expr = expr.replace("\\tfrac", "\\frac")
expr = expr.replace("\\dfrac", "\\frac")
expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers.
- expr = latex2text.LatexNodes2Text().latex_to_text(expr)
+ # expr = latex2text.LatexNodes2Text().latex_to_text(expr)
# Replace the specific characters that this parser uses.
expr = expr.replace("√", "sqrt")
@@ -279,23 +240,23 @@ def _normalize(expr: str) -> str:
expr = expr.replace("trillion", "*10^12")
for unit in [
- "degree",
- "cm",
- "centimeter",
- "meter",
- "mile",
- "second",
- "minute",
- "hour",
- "day",
- "week",
- "month",
- "year",
- "foot",
- "feet",
- "inch",
- "yard",
- "liter",
+ "degree",
+ "cm",
+ "centimeter",
+ "meter",
+ "mile",
+ "second",
+ "minute",
+ "hour",
+ "day",
+ "week",
+ "month",
+ "year",
+ "foot",
+ "feet",
+ "inch",
+ "yard",
+ "liter",
]:
expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
expr = re.sub(f"\^ *\\\\circ", "", expr)
@@ -349,7 +310,7 @@ def should_allow_eval(expr: str):
return True
-@timeout(timeout_seconds=10)
+@timeout_limit(seconds=10)
def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
are_equal = False
try:
@@ -359,7 +320,7 @@ def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
simplified = sympy.simplify(sympy_diff)
if simplified == 0:
are_equal = True
- except:
+ except Exception:
pass
return are_equal
@@ -371,8 +332,12 @@ def split_tuple(expr: str):
expr = _strip_properly_formatted_commas(expr)
if len(expr) == 0:
return []
- if (len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and
- all([ch not in expr[1:-1] for ch in TUPLE_CHARS])):
+ if (
+ len(expr) > 2
+ and expr[0] in TUPLE_CHARS
+ and expr[-1] in TUPLE_CHARS
+ and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])
+ ):
elems = [elem.strip() for elem in expr[1:-1].split(",")]
else:
elems = [expr]
@@ -411,8 +376,10 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]:
ground_truth_elems = split_tuple(ground_truth_normalized)
given_elems = split_tuple(given_normalized)
- if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or
- ground_truth_normalized[-1] != given_normalized[-1]):
+ if len(ground_truth_elems) > 1 and (
+ ground_truth_normalized[0] != given_normalized[0]
+ or ground_truth_normalized[-1] != given_normalized[-1]
+ ):
is_correct = False
elif len(ground_truth_elems) != len(given_elems):
is_correct = False
@@ -427,11 +394,13 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]:
is_correct = False
else:
is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
+ # is_correct = False
if not is_correct:
break
return is_correct, given_normalized
+
def _last_boxed_only_string(string):
idx = string.rfind("\\boxed")
if idx < 0:
@@ -459,7 +428,7 @@ def _last_boxed_only_string(string):
if left_brace_idx is None or right_brace_idx is None:
return None
- return string[left_brace_idx + 1:right_brace_idx].strip()
+ return string[left_brace_idx + 1 : right_brace_idx].strip()
def match_answer(response):
@@ -471,21 +440,21 @@ def match_answer(response):
if ans_boxed:
is_matched = True
response = ans_boxed
-
+
return is_matched, response
+
import math
-def compute_score(solution_str: str,
- ground_truth: str,
- extra_info: dict) -> float:
+
+def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> float:
"""Compute the reward score for a solution. This draws heavily from the LLM-as-judge and PRIME reward functions
-
+
Args:
solution_str: The solution string
ground_truth: The ground truth answer
extra_info: dict with additional info for the score computation
-
+
Returns:
Reward score (1.0 for correct, -1.0 for incorrect)
"""
@@ -495,30 +464,29 @@ def compute_score(solution_str: str,
# Extract answer from generated output
is_matched, extracted_model_output = match_answer(model_output)
-
+
# TWK NOTE: WE REMOVED THE RESPONSE TRUNCATION FROM math_dapo.compute_score
# Verify the solution, first check simple comparisons.
correct, pred = grade_answer(extracted_model_output, ground_truth)
- if not correct:
+ if not correct:
try:
if "\\pi" in extracted_model_output or "\\pi" in ground_truth:
equivs = []
for pi in [math.pi, 3.14]:
- equivs.append(math_equal(extracted_model_output, ground_truth, tiemout=True, pi=pi))
+ equivs.append(math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi))
correct = any(equivs)
else:
correct = math_equal(extracted_model_output, ground_truth, timeout=True)
except:
correct = False
-
# reward = 1.0 if correct else -1.0
- reward = 1.0 if correct else 0.
+ reward = 1.0 if correct else 0.0
acc = correct
return {
"score": reward,
"acc": acc,
- }
\ No newline at end of file
+ }
diff --git a/verl/utils/reward_score/prime_math/__init__.py b/verl/utils/reward_score/prime_math/__init__.py
index 8d9d273e3..50b02b3f4 100644
--- a/verl/utils/reward_score/prime_math/__init__.py
+++ b/verl/utils/reward_score/prime_math/__init__.py
@@ -24,7 +24,7 @@
import re
import sympy
-from pylatexenc import latex2text
+# from pylatexenc import latex2text
from sympy.parsing import sympy_parser
from verl.utils.py_functional import timeout_limit
diff --git a/verl/utils/reward_score/prime_math/grader.py b/verl/utils/reward_score/prime_math/grader.py
index 72bb749f2..c2e0ed7e0 100644
--- a/verl/utils/reward_score/prime_math/grader.py
+++ b/verl/utils/reward_score/prime_math/grader.py
@@ -102,9 +102,6 @@
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
-# verl related
-from verl.utils.py_functional import timeout_limit
-
def is_digit(s):
try:
@@ -325,11 +322,7 @@ def symbolic_equal(a, b, tolerance, timeout=10.0):
def _parse(s):
for f in [parse_expr, parse_latex]:
try:
- with timeout_limit(seconds=timeout):
- return f(s)
- except TimeoutError:
- print(f"Parsing timed out for {s}")
- continue
+ return f(s)
except Exception:
continue
return s
@@ -338,22 +331,14 @@ def _parse(s):
b = _parse(b)
try:
- with timeout_limit(seconds=timeout):
- if simplify(a - b) == 0:
- return True
- except TimeoutError:
- print(f"Simplification timed out for {a} - {b}")
- pass
+ if simplify(a - b) == 0:
+ return True
except Exception:
pass
try:
- with timeout_limit(seconds=timeout):
- if isclose(N(a), N(b), rel_tol=tolerance):
- return True
- except TimeoutError:
- print(f"Numerical evaluation timed out for {a}, {b}")
- pass
+ if isclose(N(a), N(b), rel_tol=tolerance):
+ return True
except Exception:
pass
return False
diff --git a/verl/utils/reward_score/puzzles_dataset.py b/verl/utils/reward_score/puzzles_dataset.py
index 72433430b..ca68c9dd6 100644
--- a/verl/utils/reward_score/puzzles_dataset.py
+++ b/verl/utils/reward_score/puzzles_dataset.py
@@ -2,23 +2,6 @@
import random
import ast
import operator
-import signal
-import contextlib
-
-class TimeoutException(Exception):
- pass
-
-@contextlib.contextmanager
-def time_limit(seconds: float):
- def signal_handler(signum, frame):
- raise TimeoutException("Timed out!")
-
- signal.setitimer(signal.ITIMER_REAL, seconds)
- signal.signal(signal.SIGALRM, signal_handler)
- try:
- yield
- finally:
- signal.setitimer(signal.ITIMER_REAL, 0)
def extract_solution(solution_str):
@@ -92,32 +75,31 @@ def compute_score(solution_str, ground_truth, extra_info: any = None, method='st
method: the method to extract the solution
timeout: maximum time in seconds to allow for computation
"""
- try:
- with time_limit(timeout):
- target = ground_truth.tolist() if not isinstance(ground_truth,list) else ground_truth
- predicted_arrangement = extract_solution(solution_str=solution_str)
-
- if predicted_arrangement is None:
- score = 0.0
-
- # Evaluate equation
- try:
- if isinstance(predicted_arrangement, list) and isinstance(target, list):
- edit_distance = compute_edit_distance(predicted_arrangement, target)
- max_possible_dist = max(len(predicted_arrangement), len(target))
- result = predicted_arrangement == target
- if result:
- score = 1.0
- elif method != 'strict':
- score = max(1.0 - (edit_distance / max_possible_dist))
- else:
- score = 0.0
- except Exception as e:
- score = 0.0
+ def _compute_with_timeout():
+ target = ground_truth.tolist() if not isinstance(ground_truth,list) else ground_truth
+ predicted_arrangement = extract_solution(solution_str=solution_str)
+
+ if predicted_arrangement is None:
+ return 0.0
- except TimeoutException:
- print("Computation timed out in puzzles_dataset")
- score = 0.0
+ # Evaluate equation
+ try:
+ if isinstance(predicted_arrangement, list) and isinstance(target, list):
+ edit_distance = compute_edit_distance(predicted_arrangement, target)
+ max_possible_dist = max(len(predicted_arrangement), len(target))
+ result = predicted_arrangement == target
+ if result:
+ return 1.0
+ elif method != 'strict':
+                    return max(0.0, 1.0 - (edit_distance / max_possible_dist))
+ else:
+ return 0.0
+ except Exception as e:
+ return 0.0
+
+ score = 0.0
+ try:
+        score = _compute_with_timeout() or 0.0
except Exception as e:
print(f"Error in compute_score in puzzles_dataset: {e}")
score = 0.0
diff --git a/verl/utils/reward_score/reasoning_gym/__init__.py b/verl/utils/reward_score/reasoning_gym/__init__.py
index 64eba44c1..3d6e2a7a3 100644
--- a/verl/utils/reward_score/reasoning_gym/__init__.py
+++ b/verl/utils/reward_score/reasoning_gym/__init__.py
@@ -1,6 +1,14 @@
import reasoning_gym
import json
import re
+import gc
+from functools import lru_cache
+
+
+@lru_cache(maxsize=64)
+def _get_cached_scorer(task):
+ """Cache scorer objects to avoid recreating them for every sample."""
+ return reasoning_gym.get_score_answer_fn(task)
def compute_score(solution_str, ground_truth, extra_info=None, item=None):
"""
@@ -15,96 +23,95 @@ def compute_score(solution_str, ground_truth, extra_info=None, item=None):
Returns:
dict: {"score": float, "acc": float}
"""
- task = None
- entry = None
-
- # 1. Parse extra_info
- extra_info_dict = {}
- metadata = None
-
- if extra_info:
- if isinstance(extra_info, str):
- try:
- extra_info_dict = json.loads(extra_info)
- except Exception:
- extra_info_dict = {}
- else:
- extra_info_dict = extra_info
-
- # Get task first
- task = extra_info_dict.get("task")
- entry = extra_info_dict.get("entry")
+ def _compute_score_with_timeout():
+ task = None
+ entry = None
- # Handle metadata field if present
- if "metadata" in extra_info_dict:
- if isinstance(extra_info_dict["metadata"], str):
+ # 1. Parse extra_info
+ extra_info_dict = {}
+ metadata = None
+
+ if extra_info:
+ if isinstance(extra_info, str):
try:
- metadata = json.loads(extra_info_dict["metadata"])
+ extra_info_dict = json.loads(extra_info)
except Exception:
- metadata = {}
- elif isinstance(extra_info_dict["metadata"], dict):
- metadata = extra_info_dict["metadata"]
-
- # 2. Try to get from item (fallback - this is rarely used in actual training)
- if not task and item and isinstance(item, dict):
- task = item.get("ability")
-
- # 3. Try to get from ground_truth
- if not task and isinstance(ground_truth, dict):
- task = ground_truth.get("task")
- entry = ground_truth
-
- if not task:
- raise ValueError("task must be provided in extra_info, item, or ground_truth dict.")
-
- # 4. Get scoring function
- scorer = reasoning_gym.get_score_answer_fn(task)
-
- # 5. Get entry
- if entry is None:
- entry = {"answer": ground_truth}
-
- # Build metadata field, prioritizing extra_info metadata
- if isinstance(entry, dict):
- if "metadata" not in entry or not isinstance(entry["metadata"], dict):
- entry["metadata"] = {}
- if metadata is not None:
- entry["metadata"].update(metadata)
- if task is not None:
- entry["metadata"]["task"] = task
- entry["metadata"]["solution_str"] = solution_str
- entry["metadata"]["ground_truth"] = ground_truth
- if extra_info is not None:
- entry["metadata"]["extra_info"] = extra_info
- if item is not None:
- entry["metadata"]["item"] = item
-
- # 6. Extract clean answer from solution_str
- clean_answer = extract_answer_from_solution(solution_str)
-
- # 7. Scoring with task-specific fixes
- debug_log_path = "reasoning_gym_debug.log"
- try:
- with open(debug_log_path, "a", encoding="utf-8") as f:
- f.write("[DEBUG] solution_str: {}\n".format(solution_str))
- f.write("[DEBUG] clean_answer: {}\n".format(clean_answer))
- f.write("[DEBUG] ground_truth: {}\n".format(ground_truth))
- f.write("[DEBUG] task: {}\n".format(task))
- f.write("[DEBUG] metadata: {}\n".format(json.dumps(entry.get("metadata", {}), ensure_ascii=False, indent=2)))
-
+ extra_info_dict = {}
+ else:
+ extra_info_dict = extra_info
+
+ # Get task first
+ task = extra_info_dict.get("task")
+ entry = extra_info_dict.get("entry")
+
+ # Handle metadata field if present
+ if "metadata" in extra_info_dict:
+ if isinstance(extra_info_dict["metadata"], str):
+ try:
+ metadata = json.loads(extra_info_dict["metadata"])
+ except Exception:
+ metadata = {}
+ elif isinstance(extra_info_dict["metadata"], dict):
+ metadata = extra_info_dict["metadata"]
+
+ # 2. Try to get from item (fallback - this is rarely used in actual training)
+ if not task and item and isinstance(item, dict):
+ task = item.get("ability")
+
+ # 3. Try to get from ground_truth
+ if not task and isinstance(ground_truth, dict):
+ task = ground_truth.get("task")
+ entry = ground_truth
+
+ if not task:
+ raise ValueError("task must be provided in extra_info, item, or ground_truth dict.")
+
+ # 4. Get scoring function (cached to avoid recreating for every sample)
+ scorer = _get_cached_scorer(task)
+
+ # 5. Get entry
+ if entry is None:
+ entry = {"answer": ground_truth}
+
+ # Build metadata field, prioritizing extra_info metadata
+ if isinstance(entry, dict):
+ if "metadata" not in entry or not isinstance(entry["metadata"], dict):
+ entry["metadata"] = {}
+ if metadata is not None:
+ entry["metadata"].update(metadata)
+ if task is not None:
+ entry["metadata"]["task"] = task
+ entry["metadata"]["solution_str"] = solution_str
+ entry["metadata"]["ground_truth"] = ground_truth
+ if extra_info is not None:
+ entry["metadata"]["extra_info"] = extra_info
+ if item is not None:
+ entry["metadata"]["item"] = item
+
+ # 6. Extract clean answer from solution_str
+ clean_answer = extract_answer_from_solution(solution_str)
+
+ # 7. Scoring with task-specific fixes (debug logging removed for memory efficiency)
+ try:
# Get raw score from reasoning_gym using clean answer
raw_score = scorer(answer=clean_answer, entry=entry)
-
+
# Apply task-specific corrections for known issues
corrected_score = apply_task_specific_corrections(task, solution_str, ground_truth, raw_score)
-
- f.write("[DEBUG] raw_score: {}\n".format(raw_score))
- f.write("[DEBUG] corrected_score: {}\n".format(corrected_score))
-
- return {"score": float(corrected_score), "acc": float(corrected_score)}
+
+ return {"score": float(corrected_score), "acc": float(corrected_score)}
+ except Exception as e:
+ # Only print errors, don't write to file
+ print(f"Error in reasoning gym scoring: {e}")
+ return {"score": 0.0, "acc": 0.0}
+
+ try:
+ return _compute_score_with_timeout()
+ except TimeoutError:
+ print("Computation timed out in reasoning_gym")
+ return {"score": 0.0, "acc": 0.0}
except Exception as e:
- with open(debug_log_path, "a", encoding="utf-8") as f:
- f.write(f"Error in reasoning gym scoring: {e}\n")
+ print(f"Error in compute_score in reasoning_gym: {e}")
return {"score": 0.0, "acc": 0.0}
diff --git a/verl/utils/reward_score/stem_llm_judge/__init__.py b/verl/utils/reward_score/stem_llm_judge/__init__.py
index a0a419617..3aeb56805 100644
--- a/verl/utils/reward_score/stem_llm_judge/__init__.py
+++ b/verl/utils/reward_score/stem_llm_judge/__init__.py
@@ -3,13 +3,15 @@
Prerequisite:
- Set env var STEM_LLM_JUDGE_URL to an OpenAI-compatible /v1/chat/completions.
+ - Supports multiple endpoints for load balancing - separate URLs with commas:
+ export STEM_LLM_JUDGE_URL="http://host1:8000,http://host2:8000,http://host3:8000"
- Launch the service with your preferred model beforehand, e.g.
vllm serve TIGER-Lab/general-verifier
- export STEM_LLM_JUDGE_URL=http://127.0.0.1:8000/v1/chat/completions
+ export STEM_LLM_JUDGE_URL=http://127.0.0.1:8000
"""
-import os, re, requests
+import os, re, random, requests
from typing import Tuple
# # ------------ Prompt template ------------------------------------------------
@@ -35,9 +37,17 @@
# ------------ Core LLM call --------------------------------------------------
def _llm_judge(question: str, student: str, reference: str, verbose: bool = False) -> bool:
- url_base = os.getenv("STEM_LLM_JUDGE_URL")
- if not url_base:
+ url_base_str = os.getenv("STEM_LLM_JUDGE_URL")
+ if not url_base_str:
raise EnvironmentError("STEM_LLM_JUDGE_URL not set")
+
+ # Support multiple endpoints separated by commas for load balancing
+ endpoints = [url.strip() for url in url_base_str.split(",") if url.strip()]
+ if not endpoints:
+ raise EnvironmentError("STEM_LLM_JUDGE_URL contains no valid endpoints")
+
+ # Randomly select an endpoint for load balancing
+ url_base = random.choice(endpoints)
url = url_base.rstrip("/") + "/v1/chat/completions"
# prompt = JUDGE_TEMPLATE.format(
diff --git a/verl/utils/reward_score/synlogic/arrow_maze_verifier.py b/verl/utils/reward_score/synlogic/arrow_maze_verifier.py
index 14f5977f8..74f084cd9 100644
--- a/verl/utils/reward_score/synlogic/arrow_maze_verifier.py
+++ b/verl/utils/reward_score/synlogic/arrow_maze_verifier.py
@@ -43,101 +43,49 @@ def verify(self, data: Data, test_solution_str: str) -> bool:
@param test_solution_str: 测试答案字符串 (JSON格式的二维数组)
@return: 答案是否正确
"""
- test_answer_str = self.extract_answer(test_solution_str)
- if not test_answer_str:
- # print("答案为空,验证失败")
- return False
-
try:
- # 解析测试答案
- test_answer = json.loads(test_answer_str)
-
- # 获取原始迷宫
- question_grid = data.metadata["maze"]
-
- # 检查答案是否符合要求
- if not self._verify_grid_size(test_answer, question_grid):
- # print("答案网格大小与题目不匹配")
- with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f:
- f.write("test_solution_str: " + test_solution_str + '\n')
- f.write("test_answer_str: " + test_answer_str + '\n')
- f.write("question_grid: " + str(data.metadata["maze"]) + '\n')
- f.write("答案网格大小与题目不匹配" + '\n')
- f.write('-'*32 + '\n')
- return False
-
- if not self._verify_number_positions(test_answer, question_grid):
- # print("答案中数字位置或值与题目不匹配")
- with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f:
- f.write("test_solution_str: " + test_solution_str + '\n')
- f.write("test_answer_str: " + test_answer_str + '\n')
- f.write("question_grid: " + str(data.metadata["maze"]) + '\n')
- f.write("答案中数字位置或值与题目不匹配" + '\n')
- f.write('-'*32 + '\n')
+ test_answer_str = self.extract_answer(test_solution_str)
+ if not test_answer_str:
+ # print("答案为空,验证失败")
return False
+
+ try:
+ # 解析测试答案
+ test_answer = json.loads(test_answer_str)
- if not self._verify_all_blanks_filled(test_answer, question_grid):
- # print("答案中有空格未被填满")
- with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f:
- f.write("test_solution_str: " + test_solution_str + '\n')
- f.write("test_answer_str: " + test_answer_str + '\n')
- f.write("question_grid: " + str(data.metadata["maze"]) + '\n')
- f.write("答案中有空格未被填满" + '\n')
- f.write('-'*32 + '\n')
- return False
+ # 获取原始迷宫
+ question_grid = data.metadata["maze"]
- if not self._verify_arrow_symbols(test_answer):
- # print("答案中包含非法箭头符号")
- with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f:
- f.write("test_solution_str: " + test_solution_str + '\n')
- f.write("test_answer_str: " + test_answer_str + '\n')
- f.write("question_grid: " + str(data.metadata["maze"]) + '\n')
- f.write("答案中包含非法箭头符号" + '\n')
- f.write('-'*32 + '\n')
- return False
-
- if not self._verify_prefilled_arrows(test_answer, question_grid):
- # print("答案中预填箭头与题目不一致")
- with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f:
- f.write("test_solution_str: " + test_solution_str + '\n')
- f.write("test_answer_str: " + test_answer_str + '\n')
- f.write("question_grid: " + str(data.metadata["maze"]) + '\n')
- f.write("答案中预填箭头与题目不一致" + '\n')
- f.write('-'*32 + '\n')
- return False
+ # 检查答案是否符合要求
+ if not self._verify_grid_size(test_answer, question_grid):
+ return False
+
+ if not self._verify_number_positions(test_answer, question_grid):
+ return False
+
+ if not self._verify_all_blanks_filled(test_answer, question_grid):
+ return False
+
+ if not self._verify_arrow_symbols(test_answer):
+ return False
+
+ if not self._verify_prefilled_arrows(test_answer, question_grid):
+ return False
+
+ if not self._verify_arrow_rays(test_answer):
+ return False
+
+ if not self._verify_number_rays(test_answer):
+ return False
- if not self._verify_arrow_rays(test_answer):
- # print("答案中存在未被射线覆盖的箭头")
- with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f:
- f.write("test_solution_str: " + test_solution_str + '\n')
- f.write("test_answer_str: " + test_answer_str + '\n')
- f.write("question_grid: " + str(data.metadata["maze"]) + '\n')
- f.write("答案中存在未被射线覆盖的箭头" + '\n')
- f.write('-'*32 + '\n')
- return False
+ # 所有验证都通过
+ # print("验证通过!")
+ return True
- if not self._verify_number_rays(test_answer):
- # print("答案中数字的射线箭头串总数不符合要求")
- with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f:
- f.write("test_solution_str: " + test_solution_str + '\n')
- f.write("test_answer_str: " + test_answer_str + '\n')
- f.write("question_grid: " + str(data.metadata["maze"]) + '\n')
- f.write("答案中数字的射线箭头串总数不符合要求" + '\n')
- f.write('-'*32 + '\n')
+ except Exception as e:
return False
-
- # 所有验证都通过
- # print("验证通过!")
- return True
-
except Exception as e:
- # print(f"验证过程中出错: {e}")
- with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f:
- f.write("test_solution_str: " + test_solution_str + '\n')
- f.write("test_answer_str: " + test_answer_str + '\n')
- f.write("question_grid: " + str(data.metadata["maze"]) + '\n')
- f.write("验证过程中出错" + str(e) + '\n')
- f.write('-'*32 + '\n')
+ print(f"Verification error (ArrowMazeVerifier): {e}")
return False
def _verify_grid_size(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool:
diff --git a/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py b/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py
index 8edab88f2..3d84e7889 100644
--- a/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py
+++ b/verl/utils/reward_score/synlogic/boolean_expressions_verifier.py
@@ -8,19 +8,21 @@ class BooleanExpressionsVerifier(Verifier):
"""
def verify(self, data: Data, test_answer: str):
try:
- test_answer = self.extract_answer(test_answer)
- if test_answer is None:
- return False
- # 提取所有字母(a-z和A-Z)
- test_answer_letters = re.findall(r'[a-zA-Z]', test_answer)
- ground_truth_letters = re.findall(r'[a-zA-Z]', data.answer)
- test_answer_letters = self.lower(test_answer_letters)
- ground_truth_letters = self.lower(ground_truth_letters)
- # 转换为集合进行比较
- test_set = set(test_answer_letters)
- ground_truth_set = set(ground_truth_letters)
-
- return test_set == ground_truth_set
+ def _verify_with_timeout():
+ test_answer_extracted = self.extract_answer(test_answer)
+ if test_answer_extracted is None:
+ return False
+ # 提取所有字母(a-z和A-Z)
+ test_answer_letters = re.findall(r'[a-zA-Z]', test_answer_extracted)
+ ground_truth_letters = re.findall(r'[a-zA-Z]', data.answer)
+ test_answer_letters = self.lower(test_answer_letters)
+ ground_truth_letters = self.lower(ground_truth_letters)
+ # 转换为集合进行比较
+ test_set = set(test_answer_letters)
+ ground_truth_set = set(ground_truth_letters)
+
+ return test_set == ground_truth_set
+ return _verify_with_timeout()
except Exception as e:
print("NOTE!!! parse error!!!! (BooleanExpressions)", e)
return False
diff --git a/verl/utils/reward_score/synlogic/campsite_verifier.py b/verl/utils/reward_score/synlogic/campsite_verifier.py
index 4aee1a5b2..4387c8f5a 100644
--- a/verl/utils/reward_score/synlogic/campsite_verifier.py
+++ b/verl/utils/reward_score/synlogic/campsite_verifier.py
@@ -11,35 +11,37 @@ class CampsiteVerifier(Verifier):
"""
def verify(self, data: Data, test_solution: str):
try:
- test_answer = self.extract_answer(test_solution)
- original_grid = data.metadata["grid"]
- row_constraints = data.metadata["row_constraints"]
- col_constraints = data.metadata["col_constraints"]
- n = data.metadata["n"]
- m = data.metadata["m"]
-
- if not test_answer:
- return False
-
- if len(test_answer) != n or any(len(row) != m for row in test_answer):
- return False
-
- if not self._check_trees_unchanged(original_grid, test_answer):
- return False
-
- if not self._check_row_constraints(test_answer, row_constraints):
- return False
-
- if not self._check_col_constraints(test_answer, col_constraints):
- return False
-
- if not self._check_tents_not_adjacent(test_answer):
- return False
-
- if not self._check_tent_tree_matching(test_answer):
- return False
-
- return True
+ def _verify_with_timeout():
+ test_answer = self.extract_answer(test_solution)
+ original_grid = data.metadata["grid"]
+ row_constraints = data.metadata["row_constraints"]
+ col_constraints = data.metadata["col_constraints"]
+ n = data.metadata["n"]
+ m = data.metadata["m"]
+
+ if not test_answer:
+ return False
+
+ if len(test_answer) != n or any(len(row) != m for row in test_answer):
+ return False
+
+ if not self._check_trees_unchanged(original_grid, test_answer):
+ return False
+
+ if not self._check_row_constraints(test_answer, row_constraints):
+ return False
+
+ if not self._check_col_constraints(test_answer, col_constraints):
+ return False
+
+ if not self._check_tents_not_adjacent(test_answer):
+ return False
+
+ if not self._check_tent_tree_matching(test_answer):
+ return False
+
+ return True
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (Campsite): {e}")
diff --git a/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py
index 3bfbdac76..3f426a71b 100644
--- a/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py
+++ b/verl/utils/reward_score/synlogic/dyck_language_errors_verifier.py
@@ -16,39 +16,41 @@ def verify(self, data: Data, test_answer: str):
@return: 回答是否正确的布尔值
"""
try:
- test_answer = self.extract_answer(test_solution=test_answer)
- # 获取正确答案
- if data.metadata["is_valid"]:
- correct_answer = "-1" # 合法序列对应-1
- else:
- correct_answer = str(data.metadata["first_error_pos"])
-
- # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'")
-
- # 清理和标准化答案
- test_answer = test_answer.strip()
-
- # 检查-1答案(合法序列)
- if correct_answer == "-1":
- # 如果正确答案是-1(合法序列),只接受-1作为回答
- if test_answer == "-1":
- is_correct = True
+ def _verify_with_timeout():
+ test_answer_extracted = self.extract_answer(test_solution=test_answer)
+ # 获取正确答案
+ if data.metadata["is_valid"]:
+ correct_answer = "-1" # 合法序列对应-1
else:
- is_correct = False
- else:
- # 正确答案是位置数字,需要验证模型回答也是相同数字
- try:
- is_correct = (int(test_answer) == int(correct_answer))
- except (ValueError, TypeError):
- # 如果模型回答不是有效数字,验证失败
- is_correct = False
-
- # if is_correct:
- # print("验证结果: 正确")
- # else:
- # print("验证结果: 错误")
+ correct_answer = str(data.metadata["first_error_pos"])
+
+ # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'")
+
+ # 清理和标准化答案
+ test_answer_clean = test_answer_extracted.strip()
+
+ # 检查-1答案(合法序列)
+ if correct_answer == "-1":
+ # 如果正确答案是-1(合法序列),只接受-1作为回答
+ if test_answer_clean == "-1":
+ is_correct = True
+ else:
+ is_correct = False
+ else:
+ # 正确答案是位置数字,需要验证模型回答也是相同数字
+ try:
+ is_correct = (int(test_answer_clean) == int(correct_answer))
+ except (ValueError, TypeError):
+ # 如果模型回答不是有效数字,验证失败
+ is_correct = False
- return is_correct
+ # if is_correct:
+ # print("验证结果: 正确")
+ # else:
+ # print("验证结果: 错误")
+
+ return is_correct
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (DyckLanguageErrors): {e}")
diff --git a/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py
index 03f2b95f9..dbf4d7c69 100644
--- a/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py
+++ b/verl/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py
@@ -16,39 +16,41 @@ def verify(self, data: Data, test_answer: str):
@return: 回答是否正确的布尔值
"""
try:
- test_answer = self.extract_answer(test_solution=test_answer)
- # 获取元数据中的正确答案
- correct_indices = data.metadata["error_indices"]
- # 格式化为正确的答案字符串格式
- expected_answer = self._format_answer(correct_indices)
-
- # print(f"验证: 模型答案='{test_answer}', 正确答案='{expected_answer}'")
-
- # 检查不明确的答案
- if "不确定" in test_answer or "不知道" in test_answer or "unclear" in test_answer.lower():
- # print("验证结果: 错误")
- return False
-
- # 清理模型答案,允许一定的格式变化
- cleaned_test_answer = self._standardize_answer(test_answer)
-
- if not correct_indices and (cleaned_test_answer == "" or cleaned_test_answer.lower() in ["无问题", "no", "无错误", "no error", "no errors", "no mistakes", "all correct"]):
- # 如果没有错误,且模型回答是空字符串或表示无问题,则正确
- is_correct = True
- else:
- # 将两个答案转换为数字集合进行比较
- test_error_indices = self._extract_error_indices(cleaned_test_answer)
- expected_error_indices = set(correct_indices)
+ def _verify_with_timeout():
+ test_answer_extracted = self.extract_answer(test_solution=test_answer)
+ # 获取元数据中的正确答案
+ correct_indices = data.metadata["error_indices"]
+ # 格式化为正确的答案字符串格式
+ expected_answer = self._format_answer(correct_indices)
- # 检查两个集合是否相同
- is_correct = test_error_indices == expected_error_indices
-
- # if is_correct:
- # print("验证结果: 正确")
- # else:
- # print("验证结果: 错误")
+ # print(f"验证: 模型答案='{test_answer}', 正确答案='{expected_answer}'")
+
+ # 检查不明确的答案
+ if "不确定" in test_answer_extracted or "不知道" in test_answer_extracted or "unclear" in test_answer_extracted.lower():
+ # print("验证结果: 错误")
+ return False
+
+ # 清理模型答案,允许一定的格式变化
+ cleaned_test_answer = self._standardize_answer(test_answer_extracted)
+
+ if not correct_indices and (cleaned_test_answer == "" or cleaned_test_answer.lower() in ["无问题", "no", "无错误", "no error", "no errors", "no mistakes", "all correct"]):
+ # 如果没有错误,且模型回答是空字符串或表示无问题,则正确
+ is_correct = True
+ else:
+ # 将两个答案转换为数字集合进行比较
+ test_error_indices = self._extract_error_indices(cleaned_test_answer)
+ expected_error_indices = set(correct_indices)
+
+ # 检查两个集合是否相同
+ is_correct = test_error_indices == expected_error_indices
- return is_correct
+ # if is_correct:
+ # print("验证结果: 正确")
+ # else:
+ # print("验证结果: 错误")
+
+ return is_correct
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (DyckLanguageReasoningErrors): {e}")
diff --git a/verl/utils/reward_score/synlogic/dyck_language_verifier.py b/verl/utils/reward_score/synlogic/dyck_language_verifier.py
index e986f66c5..1b30d2b11 100644
--- a/verl/utils/reward_score/synlogic/dyck_language_verifier.py
+++ b/verl/utils/reward_score/synlogic/dyck_language_verifier.py
@@ -16,23 +16,25 @@ def verify(self, data: Data, test_answer: str) -> bool:
@return: 回答是否正确的布尔值
"""
try:
- # 获取元数据中的完整序列
- full_sequence = data.metadata["full_sequence"]
-
- # print(f"验证: 模型答案='{test_answer}', 完整序列='{full_sequence}'")
-
- # 从模型回答中提取答案
- extracted_answer = self.extract_answer(test_answer)
-
- # 检查答案是否完全匹配
- is_correct = (extracted_answer == full_sequence)
-
- # if is_correct:
- # print("验证结果: 正确")
- # else:
- # print("验证结果: 错误")
+ def _verify_with_timeout():
+ # 获取元数据中的完整序列
+ full_sequence = data.metadata["full_sequence"]
+
+ # print(f"验证: 模型答案='{test_answer}', 完整序列='{full_sequence}'")
+
+ # 从模型回答中提取答案
+ extracted_answer = self.extract_answer(test_answer)
+
+ # 检查答案是否完全匹配
+ is_correct = (extracted_answer == full_sequence)
- return is_correct
+ # if is_correct:
+ # print("验证结果: 正确")
+ # else:
+ # print("验证结果: 错误")
+
+ return is_correct
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (DyckLanguage): {e}")
diff --git a/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py b/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py
index b4c8bdb25..aea853c4d 100644
--- a/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py
+++ b/verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py
@@ -26,19 +26,25 @@ def verify(self, data: Data, test_answer: str) -> bool:
@param test_answer: The answer provided by the LLM to verify
@return: bool indicating whether the answer is correct
"""
- # Extract the expected answer from the Data object
- expected_answer = data.answer if data and hasattr(data, 'answer') else ""
-
- # For empty strings, compare directly
- if not expected_answer and not test_answer:
- return True
-
- # Extract and normalize both answers
- normalized_expected = self._extract_answer(expected_answer)
- normalized_test = self._extract_answer(test_answer)
-
- # Direct comparison of normalized answers
- return normalized_expected == normalized_test
+ try:
+ def _verify_with_timeout():
+ # Extract the expected answer from the Data object
+ expected_answer = data.answer if data and hasattr(data, 'answer') else ""
+
+ # For empty strings, compare directly
+ if not expected_answer and not test_answer:
+ return True
+
+ # Extract and normalize both answers
+ normalized_expected = self._extract_answer(expected_answer)
+ normalized_test = self._extract_answer(test_answer)
+
+ # Direct comparison of normalized answers
+ return normalized_expected == normalized_test
+ return _verify_with_timeout()
+ except Exception as e:
+ # print(f"Verification error (BuggyTable): {e}")
+ return False
def _is_raw_numeric_answer(self, value: str) -> bool:
"""
diff --git a/verl/utils/reward_score/synlogic/goods_exchange_verifier.py b/verl/utils/reward_score/synlogic/goods_exchange_verifier.py
index 922e92689..31392d024 100644
--- a/verl/utils/reward_score/synlogic/goods_exchange_verifier.py
+++ b/verl/utils/reward_score/synlogic/goods_exchange_verifier.py
@@ -15,28 +15,30 @@ def verify(self, data: Data, test_solution: str):
@return: 回答是否正确的布尔值
"""
try:
- test_answer = self.extract_answer(test_solution)
- # 获取元数据中的正确答案
- correct_answer = data.metadata["owns_after"]
-
- # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'")
-
- # 解析模型答案
- model_ownership = self._parse_answer(test_answer)
- # 解析正确答案
- correct_ownership = self._parse_answer(correct_answer)
-
- # 比较两个答案是否完全一致
- is_correct = self._compare_answers(model_ownership, correct_ownership)
-
- # if is_correct:
- # print("验证结果: 正确")
- # else:
- # print("验证结果: 错误")
- # # 打印详细的不匹配信息
- # self._print_difference(model_ownership, correct_ownership)
+ def _verify_with_timeout():
+ test_answer = self.extract_answer(test_solution)
+ # 获取元数据中的正确答案
+ correct_answer = data.metadata["owns_after"]
- return is_correct
+ # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'")
+
+ # 解析模型答案
+ model_ownership = self._parse_answer(test_answer)
+ # 解析正确答案
+ correct_ownership = self._parse_answer(correct_answer)
+
+ # 比较两个答案是否完全一致
+ is_correct = self._compare_answers(model_ownership, correct_ownership)
+
+ # if is_correct:
+ # print("验证结果: 正确")
+ # else:
+ # print("验证结果: 错误")
+ # # 打印详细的不匹配信息
+ # self._print_difference(model_ownership, correct_ownership)
+
+ return is_correct
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (GoodsExchange): {e}")
diff --git a/verl/utils/reward_score/synlogic/math_path_verifier.py b/verl/utils/reward_score/synlogic/math_path_verifier.py
index d0df1c4a8..d447a5c92 100644
--- a/verl/utils/reward_score/synlogic/math_path_verifier.py
+++ b/verl/utils/reward_score/synlogic/math_path_verifier.py
@@ -18,59 +18,66 @@ def verify(self, data: Data, test_answer: str):
@return: 回答是否正确的布尔值
"""
try:
- test_answer = self.extract_answer(test_solution=test_answer)
- except Exception as e:
- print(f"NOTE!!! parse error!!!! (MathPath): {e}")
- return False
+ def _verify_with_timeout():
+ try:
+ test_answer_extracted = self.extract_answer(test_solution=test_answer)
+ except Exception as e:
+ # print(f"NOTE!!! parse error!!!! (MathPath): {e}")
+ return False
- try:
- # 解析元数据
- metadata = data.metadata
- ref_expr = metadata["ref_expr"]
- query_expr = metadata["query_expr"]
+ try:
+ # 解析元数据
+ metadata = data.metadata
+ ref_expr = metadata["ref_expr"]
+ query_expr = metadata["query_expr"]
- # 验证数字是否被篡改,数字是否在0-9之间。
- test_tmp = test_answer.replace(' ', '').strip()
- query_tmp = query_expr.replace(' ', '').strip()
- ref_tmp = ref_expr.replace(' ', '').strip()
- query_nums = [x for x in query_tmp if '0'<=x<='9' or x=='?']
- test_nums = [x for x in test_tmp if '0'<=x<='9']
- if len(query_nums)!=len(test_nums):
- # print(f"所填数字数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
- return False
- else:
- for ind, x in enumerate(query_nums):
- if x=='?':
- continue
- if x!=test_nums[ind]:
- # print(f"表达式数字被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
+ # 验证数字是否被篡改,数字是否在0-9之间。
+ test_tmp = test_answer_extracted.replace(' ', '').strip()
+ query_tmp = query_expr.replace(' ', '').strip()
+ ref_tmp = ref_expr.replace(' ', '').strip()
+ query_nums = [x for x in query_tmp if '0'<=x<='9' or x=='?']
+ test_nums = [x for x in test_tmp if '0'<=x<='9']
+ if len(query_nums)!=len(test_nums):
+ # print(f"所填数字数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
return False
+ else:
+ for ind, x in enumerate(query_nums):
+ if x=='?':
+ continue
+ if x!=test_nums[ind]:
+ # print(f"表达式数字被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
+ return False
+
+ query_symbols = [x for x in query_tmp if x in ['+', '-', '*', '/', '%']]
+ test_symbols = [x for x in test_tmp if x in ['+', '-', '*', '/', '%']]
+ if len(query_symbols)!=len(test_symbols):
+ # print(f"表达式运算符号数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
+ return False
+ else:
+ for ind, x in enumerate(query_symbols):
+ if x!=test_symbols[ind]:
+ # print(f"表达式运算符号被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
+ return False
- query_symbols = [x for x in query_tmp if x in ['+', '-', '*', '/', '%']]
- test_symbols = [x for x in test_tmp if x in ['+', '-', '*', '/', '%']]
- if len(query_symbols)!=len(test_symbols):
- # print(f"表达式运算符号数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
- return False
- else:
- for ind, x in enumerate(query_symbols):
- if x!=test_symbols[ind]:
- # print(f"表达式运算符号被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
+ # 验证回答中的等式是否成立
+ try:
+ tmp = test_tmp.replace('=', '==')
+ if not eval(tmp):
+ # print(f"等式不成立!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
+ return False
+ except:
+ # print(f"运算表达式错误!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
return False
-
- # 验证回答中的等式是否成立
- try:
- tmp = test_tmp.replace('=', '==')
- if not eval(tmp):
- # print(f"等式不成立!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
+
+
+ # 所有检查都通过
+ # print("验证结果: 正确")
+ return True
+
+ except Exception as e:
+ print(f"Verification error (MathPath): {e}")
return False
- except:
- # print(f"运算表达式错误!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}")
- return False
-
-
- # 所有检查都通过
- # print("验证结果: 正确")
- return True
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (MathPath): {e}")
diff --git a/verl/utils/reward_score/synlogic/minesweeper_verifier.py b/verl/utils/reward_score/synlogic/minesweeper_verifier.py
index 1b2791403..8bb0240cd 100644
--- a/verl/utils/reward_score/synlogic/minesweeper_verifier.py
+++ b/verl/utils/reward_score/synlogic/minesweeper_verifier.py
@@ -12,17 +12,19 @@ class MinesweeperVerifier(Verifier):
"""
def verify(self, data: Data, test_solution: str, **kwargs):
try:
- # 从解答中提取地雷坐标
- predicted_mines = self.extract_answer(test_solution)
-
- # 从metadata中获取确定性地雷坐标
- expected_mines = data.metadata["current_mines"]
-
- # 验证提取的坐标是否正确
- if set(tuple(mine) for mine in predicted_mines) == set(tuple(mine) for mine in expected_mines):
- return True
-
- return False
+ def _verify_with_timeout():
+ # 从解答中提取地雷坐标
+ predicted_mines = self.extract_answer(test_solution)
+
+ # 从metadata中获取确定性地雷坐标
+ expected_mines = data.metadata["current_mines"]
+
+ # 验证提取的坐标是否正确
+ if set(tuple(mine) for mine in predicted_mines) == set(tuple(mine) for mine in expected_mines):
+ return True
+
+ return False
+ return _verify_with_timeout()
except Exception as e:
# 如果验证过程中发生任何错误,返回False
diff --git a/verl/utils/reward_score/synlogic/norinori_verifier.py b/verl/utils/reward_score/synlogic/norinori_verifier.py
index 95cc98ed1..65793b666 100644
--- a/verl/utils/reward_score/synlogic/norinori_verifier.py
+++ b/verl/utils/reward_score/synlogic/norinori_verifier.py
@@ -24,54 +24,56 @@ def verify(self, data: Data, test_solution: str):
bool -- 答案是否正确
"""
try:
- # 从游戏数据中获取区域网格
- region_grid = data.metadata["region_grid"]
- n = len(region_grid)
-
- # 解析答案
- dominoes = self._parse_answer(test_solution)
- if dominoes is None:
- return False
-
- # 检查多米诺形状
- if not self._check_domino_shapes(dominoes):
- return False
-
- # 创建覆盖网格
- covered = [[False for _ in range(n)] for _ in range(n)]
- for domino in dominoes:
- for i, j in domino:
- # 转换为0-indexed
- i -= 1
- j -= 1
- if i < 0 or i >= n or j < 0 or j >= n:
- return False # 坐标超出范围
- if covered[i][j]:
- return False # 格子被多次覆盖
- covered[i][j] = True
-
- # 检查多米诺之间是否相邻
- if not self._check_domino_adjacency(dominoes, n):
- return False
-
- # 检查每个区域是否恰好有两个格子被覆盖
- region_coverage = defaultdict(int)
- for i in range(n):
- for j in range(n):
- if covered[i][j] and region_grid[i][j] != "X":
- region_coverage[region_grid[i][j]] += 1
-
- for region, count in region_coverage.items():
- if count != 2:
+ def _verify_with_timeout():
+ # 从游戏数据中获取区域网格
+ region_grid = data.metadata["region_grid"]
+ n = len(region_grid)
+
+ # 解析答案
+ dominoes = self._parse_answer(test_solution)
+ if dominoes is None:
return False
-
- # 检查所有阴影格子是否被覆盖
- for i in range(n):
- for j in range(n):
- if region_grid[i][j] == "X" and not covered[i][j]:
+
+ # 检查多米诺形状
+ if not self._check_domino_shapes(dominoes):
+ return False
+
+ # 创建覆盖网格
+ covered = [[False for _ in range(n)] for _ in range(n)]
+ for domino in dominoes:
+ for i, j in domino:
+ # 转换为0-indexed
+ i -= 1
+ j -= 1
+ if i < 0 or i >= n or j < 0 or j >= n:
+ return False # 坐标超出范围
+ if covered[i][j]:
+ return False # 格子被多次覆盖
+ covered[i][j] = True
+
+ # 检查多米诺之间是否相邻
+ if not self._check_domino_adjacency(dominoes, n):
+ return False
+
+ # 检查每个区域是否恰好有两个格子被覆盖
+ region_coverage = defaultdict(int)
+ for i in range(n):
+ for j in range(n):
+ if covered[i][j] and region_grid[i][j] != "X":
+ region_coverage[region_grid[i][j]] += 1
+
+ for region, count in region_coverage.items():
+ if count != 2:
return False
- return True
+ # 检查所有阴影格子是否被覆盖
+ for i in range(n):
+ for j in range(n):
+ if region_grid[i][j] == "X" and not covered[i][j]:
+ return False
+
+ return True
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (Norinori): {e}")
return False
diff --git a/verl/utils/reward_score/synlogic/number_wall_verifier.py b/verl/utils/reward_score/synlogic/number_wall_verifier.py
index 541c0720a..918f96b7e 100644
--- a/verl/utils/reward_score/synlogic/number_wall_verifier.py
+++ b/verl/utils/reward_score/synlogic/number_wall_verifier.py
@@ -11,53 +11,55 @@ class NumberWallVerifier(Verifier):
"""
def verify(self, data: Data, test_solution: str, **kwargs):
try:
- # 提取答案网格
- solution_grid = self.extract_answer(test_solution)
- if not solution_grid:
- # print("Failed to extract solution grid")
- return False
-
- # 提取元数据
- original_grid = data.metadata["grid"]
- n = data.metadata["n"]
-
- # 检查网格尺寸
- if len(solution_grid) != n:
- # print(f"Solution grid has incorrect number of rows: {len(solution_grid)} != {n}")
- return False
+ def _verify_with_timeout():
+ # 提取答案网格
+ solution_grid = self.extract_answer(test_solution)
+ if not solution_grid:
+ # print("Failed to extract solution grid")
+ return False
+
+ # 提取元数据
+ original_grid = data.metadata["grid"]
+ n = data.metadata["n"]
- for row in solution_grid:
- if len(row) != n:
- # print(f"Solution grid has incorrect number of columns: {len(row)} != {n}")
+ # 检查网格尺寸
+ if len(solution_grid) != n:
+ # print(f"Solution grid has incorrect number of rows: {len(solution_grid)} != {n}")
return False
- # 检查每个单元格只包含数字、"X"或"A"
- for cell in row:
- if not (isinstance(cell, int) or cell in ["X", "A"]):
- # print(f"Invalid cell content: {cell}")
+ for row in solution_grid:
+ if len(row) != n:
+ # print(f"Solution grid has incorrect number of columns: {len(row)} != {n}")
return False
-
- # 检查原始数字是否保留
- if not self._check_original_numbers(original_grid, solution_grid):
- # print("Original numbers not preserved")
- return False
-
- # 检查墙壁布局是否有效(没有2×2或更大的连续墙块)
- if not self._check_wall_layout(solution_grid):
- # print("Invalid wall layout (2x2 or larger continuous wall blocks found)")
- return False
-
- # 检查岛屿划分是否有效
- if not self._check_islands(solution_grid):
- # print("Invalid island division")
- return False
-
- # 检查是否有斜线边
- if not self._check_diagonal_borders(solution_grid):
- # print("Invalid solution: islands have diagonal borders")
- return False
+
+ # 检查每个单元格只包含数字、"X"或"A"
+ for cell in row:
+ if not (isinstance(cell, int) or cell in ["X", "A"]):
+ # print(f"Invalid cell content: {cell}")
+ return False
- return True
+ # 检查原始数字是否保留
+ if not self._check_original_numbers(original_grid, solution_grid):
+ # print("Original numbers not preserved")
+ return False
+
+ # 检查墙壁布局是否有效(没有2×2或更大的连续墙块)
+ if not self._check_wall_layout(solution_grid):
+ # print("Invalid wall layout (2x2 or larger continuous wall blocks found)")
+ return False
+
+ # 检查岛屿划分是否有效
+ if not self._check_islands(solution_grid):
+ # print("Invalid island division")
+ return False
+
+ # 检查是否有斜线边
+ if not self._check_diagonal_borders(solution_grid):
+ # print("Invalid solution: islands have diagonal borders")
+ return False
+
+ return True
+ return _verify_with_timeout()
except Exception as e:
# 如果验证过程中发生任何错误,返回False
diff --git a/verl/utils/reward_score/synlogic/numbrix_verifier.py b/verl/utils/reward_score/synlogic/numbrix_verifier.py
index f142bc7b7..60fc4aaba 100644
--- a/verl/utils/reward_score/synlogic/numbrix_verifier.py
+++ b/verl/utils/reward_score/synlogic/numbrix_verifier.py
@@ -11,54 +11,56 @@ class NumbrixVerifier(Verifier):
"""
def verify(self, data: Data, test_solution: str):
try:
- # 提取答案网格
- test_grid = self.extract_answer(test_solution)
- if not test_grid:
- return False
-
- # 获取原始谜题和网格大小
- original_grid = data.metadata["grid"]
- n = len(original_grid)
- n_squared = n * n
-
- # 检查网格大小是否正确
- if len(test_grid) != n or any(len(row) != n for row in test_grid):
- return False
-
- # 检查是否包含所有数字 1 到 n²
- flattened_grid = [cell for row in test_grid for cell in row]
- if sorted(flattened_grid) != list(range(1, n_squared + 1)):
- return False
-
- # 检查是否保留了原始提示数字
- for i in range(n):
- for j in range(n):
- if original_grid[i][j] != "X" and test_grid[i][j] != original_grid[i][j]:
- return False
-
- # 检查连续数字是否正交相邻
- for num in range(1, n_squared):
- # 找到当前数字的位置
- current_pos = None
- next_pos = None
- for i in range(n):
- for j in range(n):
- if test_grid[i][j] == num:
- current_pos = (i, j)
- elif test_grid[i][j] == num + 1:
- next_pos = (i, j)
+ def _verify_with_timeout():
+ # 提取答案网格
+ test_grid = self.extract_answer(test_solution)
+ if not test_grid:
+ return False
- if current_pos is None or next_pos is None:
+ # 获取原始谜题和网格大小
+ original_grid = data.metadata["grid"]
+ n = len(original_grid)
+ n_squared = n * n
+
+ # 检查网格大小是否正确
+ if len(test_grid) != n or any(len(row) != n for row in test_grid):
return False
- # 检查是否正交相邻(曼哈顿距离为1)
- i1, j1 = current_pos
- i2, j2 = next_pos
- manhattan_distance = abs(i1 - i2) + abs(j1 - j2)
- if manhattan_distance != 1:
+ # 检查是否包含所有数字 1 到 n²
+ flattened_grid = [cell for row in test_grid for cell in row]
+ if sorted(flattened_grid) != list(range(1, n_squared + 1)):
return False
-
- return True
+
+ # 检查是否保留了原始提示数字
+ for i in range(n):
+ for j in range(n):
+ if original_grid[i][j] != "X" and test_grid[i][j] != original_grid[i][j]:
+ return False
+
+ # 检查连续数字是否正交相邻
+ for num in range(1, n_squared):
+ # 找到当前数字的位置
+ current_pos = None
+ next_pos = None
+ for i in range(n):
+ for j in range(n):
+ if test_grid[i][j] == num:
+ current_pos = (i, j)
+ elif test_grid[i][j] == num + 1:
+ next_pos = (i, j)
+
+ if current_pos is None or next_pos is None:
+ return False
+
+ # 检查是否正交相邻(曼哈顿距离为1)
+ i1, j1 = current_pos
+ i2, j2 = next_pos
+ manhattan_distance = abs(i1 - i2) + abs(j1 - j2)
+ if manhattan_distance != 1:
+ return False
+
+ return True
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (Numbrix): {e}")
return False
diff --git a/verl/utils/reward_score/synlogic/object_counting_verifier.py b/verl/utils/reward_score/synlogic/object_counting_verifier.py
index e63e4fd97..733d23332 100644
--- a/verl/utils/reward_score/synlogic/object_counting_verifier.py
+++ b/verl/utils/reward_score/synlogic/object_counting_verifier.py
@@ -9,17 +9,19 @@ class ObjectCountingVerifier(Verifier):
"""
def verify(self, data: Data, test_answer: str):
try:
- ground_truth = int(data.answer)
- parsed_answer = self.extract_answer(test_answer)
- with open("solution_str_OC.txt", "a") as f:
- f.write("data.answer: " + data.answer + '\n')
- f.write("test_answer: " + test_answer + '\n')
- f.write("parsed_answer" + parsed_answer + '\n')
- f.write('-'*32 + '\n')
-
- if parsed_answer is None:
- return False
- return int(parsed_answer) == ground_truth
+ def _verify_with_timeout():
+ ground_truth = int(data.answer)
+ parsed_answer = self.extract_answer(test_answer)
+ # with open("solution_str_OC.txt", "a") as f:
+ # f.write("data.answer: " + data.answer + '\n')
+ # f.write("test_answer: " + test_answer + '\n')
+ # f.write("parsed_answer" + parsed_answer + '\n')
+ # f.write('-'*32 + '\n')
+
+ if parsed_answer is None:
+ return False
+ return int(parsed_answer) == ground_truth
+ return _verify_with_timeout()
except Exception as e:
print(f"NOTE!!! parse error!!!! (ObjectCounting): {e}")
diff --git a/verl/utils/reward_score/synlogic/object_properties_verifier.py b/verl/utils/reward_score/synlogic/object_properties_verifier.py
index f8ad06ccd..fba6aef2d 100644
--- a/verl/utils/reward_score/synlogic/object_properties_verifier.py
+++ b/verl/utils/reward_score/synlogic/object_properties_verifier.py
@@ -9,12 +9,16 @@ class ObjectPropertiesVerifier(Verifier):
"""
def verify(self, data: Data, test_answer: str):
try:
- ground_truth = int(data.answer)
- parsed_answer = int(self.extract_answer(test_answer))
-
- if parsed_answer is None:
- return False
- return int(parsed_answer) == ground_truth
+ def _verify_with_timeout():
+ ground_truth = int(data.answer)
+ parsed_answer_str = self.extract_answer(test_answer)
+
+ if parsed_answer_str is None:
+ return False
+
+ parsed_answer = int(parsed_answer_str)
+ return int(parsed_answer) == ground_truth
+ return _verify_with_timeout()
except Exception as e:
print(f"NOTE!!! parse error!!!! (ObjectProperties): {e}")
diff --git a/verl/utils/reward_score/synlogic/operation_verifier.py b/verl/utils/reward_score/synlogic/operation_verifier.py
index 7f730ce8a..cdb0a83f7 100644
--- a/verl/utils/reward_score/synlogic/operation_verifier.py
+++ b/verl/utils/reward_score/synlogic/operation_verifier.py
@@ -10,12 +10,18 @@ class OperationVerifier(Verifier):
"""
def verify(self, data: Data, test_answer: str):
try:
- ground_truth = math_verify.parse(data.answer)
- parsed_answer = math_verify.parse(test_answer)
-
- if parsed_answer is None:
- return False
- return math_verify.verify(parsed_answer, ground_truth)
+ def _verify_with_timeout():
+ ground_truth = math_verify.parse(data.answer, parsing_timeout=10)
+ parsed_answer = math_verify.parse(test_answer, parsing_timeout=10)
+
+ if parsed_answer is None:
+ return False
+ return math_verify.verify(parsed_answer, ground_truth)
+
+ return _verify_with_timeout()
+ except TimeoutError:
+ print("Parsing/Verification timed out (OperationVerifier)")
+ return False
except Exception as e:
print(f"NOTE!!! parse error!!!! (OperationVerifier): {e}")
return False
diff --git a/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py
index 524f5859b..ec5c8ef33 100644
--- a/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py
+++ b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py
@@ -18,91 +18,93 @@ def verify(self, data: Data, test_solution: str):
@return: 回答是否正确的布尔值
"""
try:
- # 获取游戏元数据
- metadata = data.metadata
- n = metadata['n']
- top = metadata['top']
- bottom = metadata['bottom']
- left = metadata['left']
- right = metadata['right']
+ def _verify_with_timeout():
+ # 获取游戏元数据
+ metadata = data.metadata
+ n = metadata['n']
+ top = metadata['top']
+ bottom = metadata['bottom']
+ left = metadata['left']
+ right = metadata['right']
- self.n = n
- test_answer = self.extract_answer(test_solution)
-
- # print(f"验证: 游戏规模 {n}×{n}")
- # print(f"上方提示: {top}")
- # print(f"下方提示: {bottom}")
- # print(f"左侧提示: {left}")
- # print(f"右侧提示: {right}")
-
- # 使用提取好的网格数据
- grid = test_answer
-
- # 检查网格是否是字符串,如果是,说明提取失败
- if isinstance(grid, str):
- # print("无法提取有效网格")
- return False
+ self.n = n
+ test_answer = self.extract_answer(test_solution)
- print("提取的网格:")
- for row in grid:
- print(row)
-
- # 检查网格规模
- if len(grid) != n or any(len(row) != n for row in grid):
- # print(f"网格规模不正确,应为 {n}×{n}")
- return False
-
- # 检查数字范围 (1 到 n)
- for i in range(n):
- for j in range(n):
- if not isinstance(grid[i][j], int) or grid[i][j] < 1 or grid[i][j] > n:
- # print(f"位置 ({i+1},{j+1}) 的值 {grid[i][j]} 不在有效范围内 (1-{n})")
- return False
-
- # 检查每行唯一性
- for i in range(n):
- if len(set(grid[i])) != n:
- # print(f"第 {i+1} 行包含重复数字")
- return False
-
- # 检查每列唯一性
- for j in range(n):
- column = [grid[i][j] for i in range(n)]
- if len(set(column)) != n:
- # print(f"第 {j+1} 列包含重复数字")
- return False
-
- # 检查从上方观察
- for j in range(n):
- visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n)])
- if visible_count != top[j]:
- # print(f"从上方看第 {j+1} 列可见楼数为 {visible_count},应为 {top[j]}")
- return False
-
- # 检查从下方观察
- for j in range(n):
- visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n-1, -1, -1)])
- if visible_count != bottom[j]:
- # print(f"从下方看第 {j+1} 列可见楼数为 {visible_count},应为 {bottom[j]}")
- return False
-
- # 检查从左侧观察
- for i in range(n):
- visible_count = self._count_visible_skyscrapers(grid[i])
- if visible_count != left[i]:
- # print(f"从左侧看第 {i+1} 行可见楼数为 {visible_count},应为 {left[i]}")
+ # print(f"验证: 游戏规模 {n}×{n}")
+ # print(f"上方提示: {top}")
+ # print(f"下方提示: {bottom}")
+ # print(f"左侧提示: {left}")
+ # print(f"右侧提示: {right}")
+
+ # 使用提取好的网格数据
+ grid = test_answer
+
+ # 检查网格是否是字符串,如果是,说明提取失败
+ if isinstance(grid, str):
+ # print("无法提取有效网格")
return False
-
- # 检查从右侧观察
- for i in range(n):
- visible_count = self._count_visible_skyscrapers(grid[i][::-1])
- if visible_count != right[i]:
- # print(f"从右侧看第 {i+1} 行可见楼数为 {visible_count},应为 {right[i]}")
+
+ # print("提取的网格:")
+ # for row in grid:
+ # print(row)
+
+ # 检查网格规模
+ if len(grid) != n or any(len(row) != n for row in grid):
+ # print(f"网格规模不正确,应为 {n}×{n}")
return False
-
- # 所有检查通过
- # print("所有验证规则通过!")
- return True
+
+ # 检查数字范围 (1 到 n)
+ for i in range(n):
+ for j in range(n):
+ if not isinstance(grid[i][j], int) or grid[i][j] < 1 or grid[i][j] > n:
+ # print(f"位置 ({i+1},{j+1}) 的值 {grid[i][j]} 不在有效范围内 (1-{n})")
+ return False
+
+ # 检查每行唯一性
+ for i in range(n):
+ if len(set(grid[i])) != n:
+ # print(f"第 {i+1} 行包含重复数字")
+ return False
+
+ # 检查每列唯一性
+ for j in range(n):
+ column = [grid[i][j] for i in range(n)]
+ if len(set(column)) != n:
+ # print(f"第 {j+1} 列包含重复数字")
+ return False
+
+ # 检查从上方观察
+ for j in range(n):
+ visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n)])
+ if visible_count != top[j]:
+ # print(f"从上方看第 {j+1} 列可见楼数为 {visible_count},应为 {top[j]}")
+ return False
+
+ # 检查从下方观察
+ for j in range(n):
+ visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n-1, -1, -1)])
+ if visible_count != bottom[j]:
+ # print(f"从下方看第 {j+1} 列可见楼数为 {visible_count},应为 {bottom[j]}")
+ return False
+
+ # 检查从左侧观察
+ for i in range(n):
+ visible_count = self._count_visible_skyscrapers(grid[i])
+ if visible_count != left[i]:
+ # print(f"从左侧看第 {i+1} 行可见楼数为 {visible_count},应为 {left[i]}")
+ return False
+
+ # 检查从右侧观察
+ for i in range(n):
+ visible_count = self._count_visible_skyscrapers(grid[i][::-1])
+ if visible_count != right[i]:
+ # print(f"从右侧看第 {i+1} 行可见楼数为 {visible_count},应为 {right[i]}")
+ return False
+
+ # 所有检查通过
+ # print("所有验证规则通过!")
+ return True
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (SkyscraperPuzzle): {e}")
diff --git a/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py
index abc165d5c..fbee5a7f3 100644
--- a/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py
+++ b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py
@@ -8,14 +8,20 @@ class SpaceReasoningTreeVerifier(Verifier):
验证器用于空间推理树游戏的答案是否正确
"""
def verify(self, data: Data, test_answer: str):
- test_answer = self.extract_answer(test_answer)
- if test_answer is None:
+ try:
+ def _verify_with_timeout():
+ test_answer_extracted = self.extract_answer(test_answer)
+ if test_answer_extracted is None:
+ return False
+ test_answer_normalized = test_answer_extracted.replace(",", ",").replace(" ", "")
+ ground_truth = data.answer.replace(",", ",").replace(" ", "")
+ test_set = set(test_answer_normalized.split(","))
+ ground_truth_set = set(ground_truth.split(","))
+ return test_set == ground_truth_set
+ return _verify_with_timeout()
+ except Exception as e:
+ print(f"Verification error (SpaceReasoningTree): {e}")
return False
- test_answer = test_answer.replace(",", ",").replace(" ", "")
- ground_truth = data.answer.replace(",", ",").replace(" ", "")
- test_set = set(test_answer.split(","))
- ground_truth_set = set(ground_truth.split(","))
- return test_set == ground_truth_set
def extract_answer(self, answer_str):
# 先找到最后一个\boxed{的位置
diff --git a/verl/utils/reward_score/synlogic/space_reasoning_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py
index 249f2dc08..e2a14756b 100644
--- a/verl/utils/reward_score/synlogic/space_reasoning_verifier.py
+++ b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py
@@ -9,10 +9,16 @@ class SpaceReasoningVerifier(Verifier):
验证器用于空间推理游戏的答案是否正确
"""
def verify(self, data: Data, test_answer: str):
- test_answer = self.extract_answer(test_answer)
- if test_answer is None:
+ try:
+ def _verify_with_timeout():
+ test_answer_extracted = self.extract_answer(test_answer)
+ if test_answer_extracted is None:
+ return False
+ return test_answer_extracted.lower() == data.answer.lower()
+ return _verify_with_timeout()
+ except Exception as e:
+ print(f"Verification error (SpaceReasoning): {e}")
return False
- return test_answer.lower() == data.answer.lower()
def extract_answer(self, answer_str):
# 先找到最后一个\boxed{的位置
diff --git a/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py
index 98715e19a..161aff881 100644
--- a/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py
+++ b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py
@@ -4,8 +4,6 @@
import json
import ast
-import re
-
class StarPlacementPuzzleVerifier(Verifier):
"""
星星放置游戏验证器,用于验证模型提供的解答是否正确
@@ -19,81 +17,83 @@ def verify(self, data: Data, test_solution: str):
@return: 回答是否正确的布尔值
"""
try:
- star_coords = self.extract_answer(test_solution)
- # 获取游戏元数据
- metadata = data.metadata
- n = metadata['n']
- k = metadata['k']
- region_grid = metadata['region_grid']
-
- # print(f"验证: 游戏规模 {n}×{n}, 每行/列/区域星星数量: {k}")
-
- # 检查是否有有效的星星坐标
- if not star_coords:
- # print("无法从回答中提取有效的星星坐标")
- return False
-
- # 创建一个表示星星位置的网格
- star_grid = [[0 for _ in range(n)] for _ in range(n)]
- for region, coords in star_coords.items():
- for coord in coords:
- row, col = coord
- if row < 0 or row >= n or col < 0 or col >= n:
- # print(f"无效坐标: ({row},{col}) - 超出网格范围")
- return False
- star_grid[row][col] = 1
-
- # 打印星星网格以便调试
- # print("星星网格:")
- for row in star_grid:
- print(''.join(['* ' if cell == 1 else '. ' for cell in row]))
-
- # 1. 检查每行是否有k颗星星
- for i in range(n):
- stars_in_row = sum(star_grid[i])
- if stars_in_row != k:
- # print(f"行 {i+1} 有 {stars_in_row} 颗星星,应该有 {k} 颗")
- return False
-
- # 2. 检查每列是否有k颗星星
- for j in range(n):
- stars_in_col = sum(star_grid[i][j] for i in range(n))
- if stars_in_col != k:
- # print(f"列 {j+1} 有 {stars_in_col} 颗星星,应该有 {k} 颗")
- return False
-
- # 3. 检查每个区域是否有k颗星星
- regions = {}
- for i in range(n):
- for j in range(n):
- region = region_grid[i][j]
- if region not in regions:
- regions[region] = []
- regions[region].append((i, j))
-
- for region, cells in regions.items():
- stars_in_region = sum(star_grid[i][j] for i, j in cells)
- if stars_in_region != k:
- # print(f"区域 {region} 有 {stars_in_region} 颗星星,应该有 {k} 颗")
+ def _verify_with_timeout():
+ star_coords = self.extract_answer(test_solution)
+ # 获取游戏元数据
+ metadata = data.metadata
+ n = metadata['n']
+ k = metadata['k']
+ region_grid = metadata['region_grid']
+
+ # print(f"验证: 游戏规模 {n}×{n}, 每行/列/区域星星数量: {k}")
+
+ # 检查是否有有效的星星坐标
+ if not star_coords:
+ # print("无法从回答中提取有效的星星坐标")
return False
-
- # 4. 检查星星是否互不相邻(水平、垂直、对角线)
- for i in range(n):
+
+ # 创建一个表示星星位置的网格
+ star_grid = [[0 for _ in range(n)] for _ in range(n)]
+ for region, coords in star_coords.items():
+ for coord in coords:
+ row, col = coord
+ if row < 0 or row >= n or col < 0 or col >= n:
+ # print(f"无效坐标: ({row},{col}) - 超出网格范围")
+ return False
+ star_grid[row][col] = 1
+
+ # 打印星星网格以便调试
+ # print("星星网格:")
+ # for row in star_grid:
+ # print(''.join(['* ' if cell == 1 else '. ' for cell in row]))
+
+ # 1. 检查每行是否有k颗星星
+ for i in range(n):
+ stars_in_row = sum(star_grid[i])
+ if stars_in_row != k:
+ # print(f"行 {i+1} 有 {stars_in_row} 颗星星,应该有 {k} 颗")
+ return False
+
+ # 2. 检查每列是否有k颗星星
for j in range(n):
- if star_grid[i][j] == 1:
- # 检查周围8个方向
- for di in [-1, 0, 1]:
- for dj in [-1, 0, 1]:
- if di == 0 and dj == 0:
- continue # 跳过自身
- ni, nj = i + di, j + dj
- if 0 <= ni < n and 0 <= nj < n and star_grid[ni][nj] == 1:
- # print(f"星星在 ({i},{j}) 与星星在 ({ni},{nj}) 相邻")
- return False
-
- # 所有检查通过
- # print("所有验证规则通过!")
- return True
+ stars_in_col = sum(star_grid[i][j] for i in range(n))
+ if stars_in_col != k:
+ # print(f"列 {j+1} 有 {stars_in_col} 颗星星,应该有 {k} 颗")
+ return False
+
+ # 3. 检查每个区域是否有k颗星星
+ regions = {}
+ for i in range(n):
+ for j in range(n):
+ region = region_grid[i][j]
+ if region not in regions:
+ regions[region] = []
+ regions[region].append((i, j))
+
+ for region, cells in regions.items():
+ stars_in_region = sum(star_grid[i][j] for i, j in cells)
+ if stars_in_region != k:
+ # print(f"区域 {region} 有 {stars_in_region} 颗星星,应该有 {k} 颗")
+ return False
+
+ # 4. 检查星星是否互不相邻(水平、垂直、对角线)
+ for i in range(n):
+ for j in range(n):
+ if star_grid[i][j] == 1:
+ # 检查周围8个方向
+ for di in [-1, 0, 1]:
+ for dj in [-1, 0, 1]:
+ if di == 0 and dj == 0:
+ continue # 跳过自身
+ ni, nj = i + di, j + dj
+ if 0 <= ni < n and 0 <= nj < n and star_grid[ni][nj] == 1:
+ # print(f"星星在 ({i},{j}) 与星星在 ({ni},{nj}) 相邻")
+ return False
+
+ # 所有检查通过
+ # print("所有验证规则通过!")
+ return True
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (StarPlacementPuzzle): {e}")
diff --git a/verl/utils/reward_score/synlogic/time_sequence_verifier.py b/verl/utils/reward_score/synlogic/time_sequence_verifier.py
index 711a43d86..4b37ac641 100644
--- a/verl/utils/reward_score/synlogic/time_sequence_verifier.py
+++ b/verl/utils/reward_score/synlogic/time_sequence_verifier.py
@@ -17,35 +17,38 @@ def verify(self, data: Data, test_solution: str):
@return: 回答是否正确的布尔值
"""
try:
- test_answer = self.extract_answer(test_solution)
- # 解析元数据
- metadata = data.metadata
- true_answers = metadata['records']['answers']
-
- # 解析模型给出的列表
- try:
- test_list = json.loads(test_answer.replace(",", ","))
- except:
- print(f"NOTE!!! parse error!!!! (TimeSequence 1): {e}")
- return False
-
- try:
- if test_list[0]!=true_answers['answer_maxLen']:
- # print(f"最长会议时间不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]")
+ def _verify_with_timeout():
+ test_answer = self.extract_answer(test_solution)
+ # 解析元数据
+ metadata = data.metadata
+ true_answers = metadata["records"]["answers"]
+
+ # 解析模型给出的列表
+ try:
+ test_list = json.loads(test_answer.replace(",", ","))
+ except Exception as e:
+ # print(f"NOTE!!! parse error!!!! (TimeSequence 1): {e}")
return False
- if test_list[1]!=true_answers['answer_nums']:
- # print(f"可选会议数量不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]")
+
+ try:
+ if test_list[0] != true_answers["answer_maxLen"]:
+ # print(f"最长会议时间不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]")
+ return False
+ if test_list[1] != true_answers["answer_nums"]:
+ # print(f"可选会议数量不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]")
+ return False
+ except Exception as e:
+ # print(f"NOTE!!! parse error!!!! (TimeSequence 2): {e}")
return False
- except:
- print(f"NOTE!!! parse error!!!! (TimeSequence 2): {e}")
- return False
-
- # 所有检查都通过
- # print("验证结果: 正确")
- return True
+
+ # 所有检查都通过
+ # print("验证结果: 正确")
+ return True
+
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (TimeSequence): {e}")
- return False
+ return False
def extract_answer(self, test_solution: str):
"""
diff --git a/verl/utils/reward_score/synlogic/web_of_lies_verifier.py b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py
index 94be44fed..d0040661b 100644
--- a/verl/utils/reward_score/synlogic/web_of_lies_verifier.py
+++ b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py
@@ -15,35 +15,36 @@ def verify(self, data: Data, test_solution: str):
@return: 回答是否正确的布尔值
"""
try:
- test_answer = self.extract_answer(test_solution)
- # 获取预期答案和测试答案
- expected_answer = data.answer.lower()
-
- # 清理测试答案
- test_answer = test_answer.lower()
-
- # 提取预期答案中的真假值
- expected_truths = self._parse_answer(expected_answer)
-
- # 提取测试答案中的真假值
- test_truths = self._parse_answer(test_answer)
-
- # print(f"验证: 预期答案={expected_truths}, 模型答案={test_truths}")
-
- # 检查答案列表长度是否匹配
- if len(expected_truths) != len(test_truths):
- # print(f"验证失败: 答案长度不匹配,预期 {len(expected_truths)},实际 {len(test_truths)}")
- return False
-
- # 检查每个位置的答案是否匹配
- for i, (expected, actual) in enumerate(zip(expected_truths, test_truths)):
- if expected != actual:
- # print(f"验证失败: 第 {i+1} 个答案不匹配,预期 {expected},实际 {actual}")
+ def _verify_with_timeout():
+ test_answer = self.extract_answer(test_solution)
+ # 获取预期答案和测试答案
+ expected_answer = data.answer.lower()
+
+ # 清理测试答案
+ test_answer = test_answer.lower()
+
+ # 提取预期答案中的真假值
+ expected_truths = self._parse_answer(expected_answer)
+
+ # 提取测试答案中的真假值
+ test_truths = self._parse_answer(test_answer)
+
+ # print(f"验证: 预期答案={expected_truths}, 模型答案={test_truths}")
+
+ # 检查答案列表长度是否匹配
+ if len(expected_truths) != len(test_truths):
+ # print(f"验证失败: 答案长度不匹配,预期 {len(expected_truths)},实际 {len(test_truths)}")
return False
-
- # print("验证成功: 所有答案匹配")
- return True
-
+
+ # 检查每个位置的答案是否匹配
+ for i, (expected, actual) in enumerate(zip(expected_truths, test_truths)):
+ if expected != actual:
+ # print(f"验证失败: 第 {i+1} 个答案不匹配,预期 {expected},实际 {actual}")
+ return False
+
+ # print("验证成功: 所有答案匹配")
+ return True
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (WebOfLies): {e}")
return False
diff --git a/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py
index 8216f7605..aff680ce9 100644
--- a/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py
+++ b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py
@@ -8,19 +8,21 @@ class WordSortingMistakeVerifier(Verifier):
"""
def verify(self, data: Data, test_answer: str):
try:
- ground_truth = data.answer if data.answer is not None else "No"
- parsed_answer = self.extract_answer(test_answer)
-
- if parsed_answer is None:
- return False
-
- if parsed_answer.isdigit():
- try:
- return int(parsed_answer) == int(ground_truth)
- except Exception as e:
+ def _verify_with_timeout():
+ ground_truth = data.answer if data.answer is not None else "No"
+ parsed_answer = self.extract_answer(test_answer)
+
+ if parsed_answer is None:
return False
- else:
- return parsed_answer.lower() == ground_truth.lower()
+
+ if parsed_answer.isdigit():
+ try:
+ return int(parsed_answer) == int(ground_truth)
+ except Exception as e:
+ return False
+ else:
+ return parsed_answer.lower() == str(ground_truth).lower()
+ return _verify_with_timeout()
except Exception as e:
print(f"NOTE!!! parse error!!!! (WordSortingMistake): {e}")
return False
diff --git a/verl/utils/reward_score/synlogic/word_sorting_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_verifier.py
index 567581086..ed3ffe976 100644
--- a/verl/utils/reward_score/synlogic/word_sorting_verifier.py
+++ b/verl/utils/reward_score/synlogic/word_sorting_verifier.py
@@ -13,12 +13,14 @@ def str2list(self, answer_str):
def verify(self, data: Data, test_answer: str):
try:
- ground_truth = self.str2list(data.answer)
- parsed_answer = self.str2list(self.extract_answer(test_answer))
-
- if parsed_answer is None:
- return False
- return parsed_answer == ground_truth
+ def _verify_with_timeout():
+ ground_truth = self.str2list(data.answer)
+ parsed_answer = self.str2list(self.extract_answer(test_answer))
+
+ if parsed_answer is None:
+ return False
+ return parsed_answer == ground_truth
+ return _verify_with_timeout()
except Exception as e:
print(f"NOTE!!! parse error!!!! (WordSorting): {e}")
diff --git a/verl/utils/reward_score/synlogic/wordscapes_verifier.py b/verl/utils/reward_score/synlogic/wordscapes_verifier.py
index 2f30a7842..3e62888cb 100644
--- a/verl/utils/reward_score/synlogic/wordscapes_verifier.py
+++ b/verl/utils/reward_score/synlogic/wordscapes_verifier.py
@@ -8,152 +8,168 @@
debug_mode = False
+
class WordscapesVerifier(Verifier):
"""
Verifier for Wordscapes game
"""
+
def verify(self, data, test_solution: str):
"""
Verify whether the test answer is consistent with the gold answer
-
+
Args:
data: WordscapesData
test_solution: str containing the solution
-
+
Returns:
float: Score between 0 and 1
"""
try:
- extracted_answer = self.extract_answer(test_solution)
- if not extracted_answer:
- print("NOTE!!! parse error!!!! (Wordscapes): {e}")
- return False
-
- if debug_mode:
- for row in extracted_answer:
- print(" ".join(cell if cell != " " else "_" for cell in row))
-
- # Get grid, across_words, and down_words from data
- grid = data.metadata["grid"]
- across_words = data.metadata["across_words"]
- down_words = data.metadata["down_words"]
-
- # Validate grid dimensions
- if len(extracted_answer) != len(grid):
- # print(f"Grid height mismatch: expected {len(grid)}, got {len(extracted_answer)}")
- return False
-
- for i in range(len(grid)):
- if len(extracted_answer[i]) != len(grid[i]):
- # print(f"Grid width mismatch at row {i}: expected {len(grid[i])}, got {len(extracted_answer[i])}")
+
+ def _verify_with_timeout():
+ extracted_answer = self.extract_answer(test_solution)
+ if not extracted_answer:
+ # print("NOTE!!! parse error!!!! (Wordscapes): {e}")
return False
-
- # Check if the answer respects the grid layout (X for letters, 0 for empty)
- for i in range(len(grid)):
- for j in range(len(grid[i])):
- if grid[i][j] == "0" and extracted_answer[i][j].strip():
- # print(f"Expected empty space at position ({i},{j}), got '{extracted_answer[i][j]}'")
- return False
- if grid[i][j] == "X" and not extracted_answer[i][j].strip():
- # print(f"Expected letter at position ({i},{j}), got empty space")
+
+ if debug_mode:
+ for row in extracted_answer:
+ print(" ".join(cell if cell != " " else "_" for cell in row))
+
+ # Get grid, across_words, and down_words from data
+ grid = data.metadata["grid"]
+ across_words = data.metadata["across_words"]
+ down_words = data.metadata["down_words"]
+
+ # Validate grid dimensions
+ if len(extracted_answer) != len(grid):
+ # print(f"Grid height mismatch: expected {len(grid)}, got {len(extracted_answer)}")
+ return False
+
+ for i in range(len(grid)):
+ if len(extracted_answer[i]) != len(grid[i]):
+ # print(f"Grid width mismatch at row {i}: expected {len(grid[i])}, got {len(extracted_answer[i])}")
return False
-
- # Verify across words
- for word in across_words:
- found = False
- for i in range(len(extracted_answer)):
- row_str = ''.join(extracted_answer[i]).replace(' ', '').lower()
- if word.lower() in row_str:
- found = True
- break
- if not found and word:
- # print(f"Across word '{word}' not found in the grid")
- return 0
-
- # Verify down words
- for word in down_words:
- found = False
- for j in range(len(extracted_answer[0])):
- col = []
+
+ # Check if the answer respects the grid layout (X for letters, 0 for empty)
+ for i in range(len(grid)):
+ for j in range(len(grid[i])):
+ if grid[i][j] == "0" and extracted_answer[i][j].strip():
+ # print(f"Expected empty space at position ({i},{j}), got '{extracted_answer[i][j]}'")
+ return False
+ if grid[i][j] == "X" and not extracted_answer[i][j].strip():
+ # print(f"Expected letter at position ({i},{j}), got empty space")
+ return False
+
+ # Verify across words
+ for word_info in across_words:
+ found = False
+ target_word = word_info["word"] if isinstance(word_info, dict) else word_info
for i in range(len(extracted_answer)):
- if j < len(extracted_answer[i]):
- col.append(extracted_answer[i][j])
- col_str = ''.join(col).replace(' ', '').lower()
- if word.lower() in col_str:
- found = True
- break
- if not found and word: # Only check if word is not empty
- # print(f"Down word '{word}' not found in the grid")
- return False
-
- # All checks passed
- return True
+ row_str = "".join(extracted_answer[i]).replace(" ", "").lower()
+ if target_word.lower() in row_str:
+ found = True
+ break
+ if not found:
+ # print(f"Across word '{target_word}' not found in the grid")
+ return False
+
+ # Verify down words
+ for word_info in down_words:
+ found = False
+ target_word = word_info["word"] if isinstance(word_info, dict) else word_info
+ for j in range(len(extracted_answer[0])):
+ col = []
+ for i in range(len(extracted_answer)):
+ if j < len(extracted_answer[i]):
+ col.append(extracted_answer[i][j])
+ col_str = "".join(col).replace(" ", "").lower()
+ if target_word.lower() in col_str:
+ found = True
+ break
+ if not found:
+ # print(f"Down word '{target_word}' not found in the grid")
+ return False
+
+ # All checks passed
+ return True
+
+ return _verify_with_timeout()
except Exception as e:
print(f"Verification error (Wordscapes): {e}")
return False
-
+
def extract_answer(self, test_solution: str):
"""
Extract the answer from the test solution
-
+
Args:
test_solution: str
-
+
Returns:
list: 2D grid of the answer or None if extraction fails
"""
try:
# Remove thoughts if present
- if THOUGHT_DELIMITER_START in test_solution and THOUGHT_DELIMITER_END in test_solution:
+ if (
+ THOUGHT_DELIMITER_START in test_solution
+ and THOUGHT_DELIMITER_END in test_solution
+ ):
# Extract only the part after the thoughts
thought_end_pos = test_solution.rfind(THOUGHT_DELIMITER_END)
if thought_end_pos >= 0:
- test_solution = test_solution[thought_end_pos + len(THOUGHT_DELIMITER_END):]
-
+ test_solution = test_solution[
+ thought_end_pos + len(THOUGHT_DELIMITER_END) :
+ ]
+
# Clean up the response and find the grid pattern
# Look for a pattern like [[...]] or [[[...]]]
- grid_pattern = re.search(r'\[\s*\[(?:\s*\[)?(.+?)(?:\]\s*)?\]\s*\]', test_solution, re.DOTALL)
+ grid_pattern = re.search(
+ r"\[\s*\[(?:\s*\[)?(.+?)(?:\]\s*)?\]\s*\]", test_solution, re.DOTALL
+ )
if not grid_pattern:
return None
-
+
grid_text = grid_pattern.group(1)
-
+
# Handle various formats
rows = []
-
+
# Check if rows are separated by commas
- split_rows = re.split(r'\],\s*\[', grid_text)
-
+ split_rows = re.split(r"\],\s*\[", grid_text)
+
for row_text in split_rows:
# Clean the row text and extract characters
- row_text = row_text.strip().strip('[],')
-
+ row_text = row_text.strip().strip("[],")
+
# Extract quoted characters: "X" or 'X' or just X
chars = []
-
+
# Look for quoted strings or standalone characters
- char_matches = re.findall(r'\"([^\"]*)\"|\'([^\']*)\'|([^,\s]+)', row_text)
-
+ char_matches = re.findall(
+ r"\"([^\"]*)\"|\'([^\']*)\'|([^,\s]+)", row_text
+ )
+
for match in char_matches:
# Take the first non-empty group from each match
char = next((x for x in match if x), "")
-
+
# Handle numeric or empty values (0, "", '')
if char == "0" or char == "":
char = " "
-
+
chars.append(char)
-
+
if chars: # Only add non-empty rows
rows.append(chars)
-
+
# Make sure we have a valid grid
if not rows or not all(rows):
return None
-
+
return rows
-
+
except Exception as e:
print(f"NOTE!!! parse error!!!! (Wordscapes): {e}")
return None
-
\ No newline at end of file
diff --git a/verl/utils/reward_score/zebra_puzzle.py b/verl/utils/reward_score/zebra_puzzle.py
index 01da4d674..27aae6433 100644
--- a/verl/utils/reward_score/zebra_puzzle.py
+++ b/verl/utils/reward_score/zebra_puzzle.py
@@ -3,23 +3,6 @@
import ast
import operator
import json
-import signal
-import contextlib
-
-class TimeoutException(Exception):
- pass
-
-@contextlib.contextmanager
-def time_limit(seconds: float):
- def signal_handler(signum, frame):
- raise TimeoutException("Timed out!")
-
- signal.setitimer(signal.ITIMER_REAL, seconds)
- signal.signal(signal.SIGALRM, signal_handler)
- try:
- yield
- finally:
- signal.setitimer(signal.ITIMER_REAL, 0)
def extract_solution(solution_str):
@@ -68,22 +51,20 @@ def compute_accuracy(answer, ground_truth):
return accuracy
def compute_score(solution_str, ground_truth, extra_info: any = None, method='strict', timeout: float = 10.0):
- try:
- with time_limit(timeout):
- predicted_arrangement = extract_solution(solution_str)
+ def _compute_with_timeout():
+ predicted_arrangement = extract_solution(solution_str)
- if predicted_arrangement is None:
- score = 0.0
- else:
- try:
- accuracy = compute_accuracy(predicted_arrangement, ground_truth)
- score = accuracy
- except Exception as e:
- score = 0.0
-
- except TimeoutException:
- print("Computation timed out in zebra_puzzle")
- score = 0.0
+ if predicted_arrangement is None:
+ return 0.0
+ else:
+ try:
+ accuracy = compute_accuracy(predicted_arrangement, ground_truth)
+ return accuracy
+ except Exception as e:
+ return 0.0
+
+ try:
+ score = _compute_with_timeout()
except Exception as e:
print(f"Error in compute_score in zebra_puzzle: {e}")
score = 0.0
diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py
index af6199732..011d2ebc5 100644
--- a/verl/workers/config/actor.py
+++ b/verl/workers/config/actor.py
@@ -49,6 +49,8 @@ class PolicyLossConfig(BaseConfig):
clip_cov_ub: float = 5.0
kl_cov_ratio: float = 0.0002
ppo_kl_coef: float = 0.1
+ cispo_clip_ratio_high: float = 0.2
+ cispo_clip_ratio_low: float = 0.2
@dataclass
@@ -225,6 +227,7 @@ class FSDPActorConfig(ActorConfig):
"""
strategy: str = "fsdp"
+ dtype: str = "bfloat16"
grad_clip: float = 1.0
ulysses_sequence_parallel_size: int = 1
entropy_from_logits_with_chunking: bool = False
diff --git a/verl/workers/reward_manager/async_mp.py b/verl/workers/reward_manager/async_mp.py
index 09b913c74..4577b11c7 100644
--- a/verl/workers/reward_manager/async_mp.py
+++ b/verl/workers/reward_manager/async_mp.py
@@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
+import gc
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from functools import partial
@@ -116,6 +117,9 @@ async def parallel_compute_score_async(
actual_idx = start_idx + i
results[actual_idx] = result
+ # Force garbage collection after each batch to prevent memory accumulation
+ gc.collect()
+
except Exception:
for pid, proc in executor._processes.items():
try:
@@ -149,6 +153,7 @@ def __init__(
overlong_buffer_cfg=None,
batch_size=2048,
shuffle_batch=True,
+ num_processes=32,
**kwargs,
) -> None:
self.tokenizer = tokenizer
@@ -159,7 +164,7 @@ def __init__(
self.max_resp_len = max_resp_len
self.batch_size = batch_size
self.shuffle_batch = shuffle_batch
-
+ self.num_processes = num_processes
if self.overlong_buffer_cfg is not None:
assert self.max_resp_len is not None, (
f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None"
@@ -264,7 +269,7 @@ def __call__(self, data: DataProto, return_dict: bool = False):
solutions,
ground_truths,
extra_infos,
- num_processes=64,
+ num_processes=self.num_processes,
batch_size=self.batch_size,
shuffle=self.shuffle_batch,
)
@@ -328,6 +333,41 @@ def __call__(self, data: DataProto, return_dict: bool = False):
# print(f"[DEBUG] Non-zero elements in reward_tensor: {(reward_tensor != 0).sum().item()}")
# print(f"[DEBUG] Unique data sources processed: {list(already_print_data_sources.keys())}")
+ # Aggregate and print timing statistics by reward type
+ timing_by_type = defaultdict(lambda: {"count": 0, "total": 0.0, "max": 0.0})
+        # NOTE(review): timing metadata may be missing when no sample recorded it;
+        # the existence check below guards aggregation against absent keys.
+
+ if "_reward_time" in reward_extra_info and "_data_source" in reward_extra_info:
+ for reward_time, ds in zip(
+ reward_extra_info["_reward_time"], reward_extra_info["_data_source"], strict=False
+ ):
+ # Extract prefix (e.g., "codegen" from "codegen__deduped_leetcode2k")
+ if "__" in ds:
+ prefix = ds.split("__")[0]
+ elif "_" in ds:
+ prefix = ds.split("_")[0]
+ else:
+ prefix = ds
+ timing_by_type[prefix]["count"] += 1
+ timing_by_type[prefix]["total"] += reward_time
+ timing_by_type[prefix]["max"] = max(timing_by_type[prefix]["max"], reward_time)
+
+ # Print timing summary sorted by total time (descending)
+ print("\n=== REWARD TIMING BY TYPE ===")
+ for rtype, stats in sorted(timing_by_type.items(), key=lambda x: -x[1]["total"]):
+ avg = stats["total"] / max(stats["count"], 1)
+ print(
+ f" {rtype:20s}: {stats['count']:5d} samples, "
+ f"avg={avg*1000:8.2f}ms, max={stats['max']*1000:8.2f}ms, total={stats['total']:8.2f}s"
+ )
+ print("=" * 50 + "\n")
+
+ # Remove timing metadata keys before returning - they have inconsistent lengths
+ # when samples timeout/fail and would cause IndexError in batch reordering
+ reward_extra_info.pop("_reward_time", None)
+ reward_extra_info.pop("_data_source", None)
+
if return_dict:
return {
"reward_tensor": reward_tensor,
diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py
index d8b6b4742..9d3aa6b09 100644
--- a/verl/workers/reward_manager/dapo.py
+++ b/verl/workers/reward_manager/dapo.py
@@ -109,6 +109,7 @@ def __call__(self, data: DataProto, return_dict: bool = False):
reward_extra_info[key].append(value)
else:
score = result
+ reward_extra_info["score"].append(score)
reward_extra_info["acc"].append(score)
reward = score