-
Notifications
You must be signed in to change notification settings - Fork 17
Adding Tool-N1 data set to training mix with sync rl #154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jb3618columbia
wants to merge
1
commit into
verl-latest-cispo
Choose a base branch
from
single_turn_tool_calling_data
base: verl-latest-cispo
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+718
−0
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
234 changes: 234 additions & 0 deletions
234
scripts/train/tool_n1_test_multinode_rl_qwen2.5_32b_base_fsdp.sh
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,234 @@ | ||
| #!/bin/bash | ||
| #SBATCH --job-name=tool-n1-multinode-rl-qwen2.5-7b-base-fsdp | ||
| #SBATCH --nodes=2 | ||
| #SBATCH --ntasks=2 | ||
| #SBATCH --ntasks-per-node=1 | ||
| #SBATCH --gres=gpu:8 | ||
| #SBATCH --cpus-per-task=128 | ||
| #SBATCH --mem=0 | ||
| #SBATCH --output=slurm/%x-%j.out | ||
| #SBATCH --error=slurm/%x-%j.err | ||
| #SBATCH --exclusive | ||
| #SBATCH --time=720:00:00 | ||
| #SBATCH --partition=main | ||
| #SBATCH --account=iq | ||
|
|
||
|
|
||
| # =================== Frequently Used Variables =================== | ||
| RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch | ||
| export STEM_LLM_JUDGE_URL="<STEM_LLM_JUDGE_URL>" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain | ||
jb3618columbia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # =================== Cluster Environment =================== | ||
| export CONDA_BIN_PATH=/mnt/weka/home/jalaj.bhandari/miniconda3/envs/jalaj_sync_rl/bin/ | ||
jb3618columbia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| export NCCL_TIMEOUT_SECONDS=14400 # Increased to 4 hours for 2-node stability during checkpoint saves | ||
| export TORCH_NCCL_ENABLE_MONITORING=0 | ||
| export NCCL_DEBUG=info | ||
| export NCCL_NET=IB | ||
| export NCCL_IB_HCA="mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7" | ||
| export NCCL_CROSS_NIC=1 | ||
| export NCCL_IB_TC=136 | ||
| export NCCL_SOCKET_IFNAME="^lo,docker,virbr" | ||
| export CUDA_DEVICE_MAX_CONNECTIONS=8 | ||
| export NCCL_NVLS_ENABLE=1 | ||
| export NCCL_ASYNC_ERROR_HANDLING=1 # Handle NCCL errors gracefully | ||
| export TORCH_NCCL_BLOCKING_WAIT=0 # Non-blocking NCCL wait | ||
|
|
||
| # Get the list of allocated nodes | ||
| nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) | ||
| echo "Nodes to check: ${nodes[@]}" | ||
|
|
||
| # We'll track PIDs so we can wait on them and detect errors | ||
| declare -A pids | ||
| export head_node=${nodes[0]} | ||
| head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) | ||
| port=6379 | ||
| address_head=$head_node_ip:$port | ||
|
|
||
| export worker_num=$SLURM_NNODES | ||
| export HYDRA_FULL_ERROR=1 | ||
| export VLLM_USE_V1=1 | ||
|
|
||
| # =================== Data Mixture =================== | ||
| TRAIN_DATA_DIR=/mnt/weka/shrd/k2tls/jalaj | ||
| TEST_DATA_DIR=/mnt/weka/shrd/k2tls/jalaj | ||
jb3618columbia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| tool_n1_train_path=${TRAIN_DATA_DIR}/tool-n1_train.parquet | ||
| tool_n1_test_path=${TEST_DATA_DIR}/tool-n1_test.parquet | ||
|
|
||
| train_files="['${tool_n1_train_path}']" | ||
| test_files="['${tool_n1_test_path}']" | ||
jb3618columbia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # =================== Model =================== | ||
| BASE_MODEL=Qwen/Qwen2.5-7B | ||
|
|
||
| # =================== Logging =================== | ||
| WANDB_PROJECT=Reasoning360 | ||
| WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} | ||
|
|
||
| # If RESUME_CKPT_DIR is not empty, resume from the checkpoint | ||
| if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then | ||
| WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" | ||
| fi | ||
|
|
||
|
|
||
| # =================== Ray start =================== | ||
| # ray stop at all nodes | ||
| srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop | ||
|
|
||
| sleep 10 | ||
| # Remove existing Ray cluster | ||
| srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster | ||
|
|
||
| # Start Ray head node | ||
| srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ | ||
| ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ | ||
| --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & | ||
|
|
||
| sleep 10 | ||
|
|
||
| # Start Ray worker nodes | ||
| for ((i = 1; i < worker_num; i++)); do | ||
| node_i=${nodes[$i]} | ||
| echo "Starting WORKER $i at $node_i" | ||
| srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ | ||
| ${CONDA_BIN_PATH}ray start --address "$address_head" \ | ||
| --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & | ||
| done | ||
| sleep 10 | ||
|
|
||
|
|
||
| # =================== RL Config =================== | ||
| # Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. | ||
|
|
||
| adv_estimator=grpo | ||
|
|
||
| use_kl_in_reward=False | ||
| kl_coef=0.0 | ||
| use_kl_loss=False | ||
| kl_loss_coef=0.0 | ||
|
|
||
| clip_ratio_low=0.2 | ||
| clip_ratio_high=0.2 | ||
|
|
||
| max_prompt_length=$((1024 * 4)) | ||
| max_response_length=$((1024 * 8)) | ||
| enable_overlong_buffer=False | ||
| overlong_buffer_len=$((1024 * 4)) | ||
| overlong_penalty_factor=1.0 | ||
|
|
||
| loss_agg_mode="token-mean" | ||
|
|
||
| enable_filter_groups=False | ||
| filter_groups_metric=acc | ||
| max_num_gen_batches=10 | ||
| train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n | ||
| gen_prompt_bsz=$((train_prompt_bsz * 1)) | ||
| n_resp_per_prompt=16 | ||
| train_prompt_mini_bsz=64 # model grad update batchsize | ||
|
|
||
| # Algorithm | ||
| temperature=1.0 | ||
| top_p=1.0 | ||
| top_k=-1 # 0 for HF rollout, -1 for vLLM rollout | ||
|
|
||
| # Training config | ||
| sp_size=1 | ||
| gen_tp=2 | ||
| gen_max_num_seqs=1024 | ||
| infer_micro_batch_size=null | ||
| train_micro_batch_size=null | ||
| use_dynamic_bsz=True | ||
| actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up model forward & backward but note memory overflow | ||
| infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up modelforward, but note memory overflow | ||
| offload=True | ||
|
|
||
| # =================== Start RL training =================== | ||
| "${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ | ||
| --config-path=config \ | ||
| --config-name="dapo_fsdp_config.yaml" \ | ||
| algorithm.adv_estimator=${adv_estimator} \ | ||
| algorithm.use_kl_in_reward=${use_kl_in_reward} \ | ||
| algorithm.kl_ctrl.kl_coef=${kl_coef} \ | ||
| algorithm.filter_groups.enable=${enable_filter_groups} \ | ||
| algorithm.filter_groups.metric=${filter_groups_metric} \ | ||
| algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ | ||
| data.train_files="$train_files" \ | ||
| data.val_files="$test_files" \ | ||
| data.prompt_key=prompt \ | ||
| data.truncation='right' \ | ||
| data.max_prompt_length=${max_prompt_length} \ | ||
| data.max_response_length=${max_response_length} \ | ||
| data.train_batch_size=${train_prompt_bsz} \ | ||
| data.gen_batch_size=${gen_prompt_bsz} \ | ||
| actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ | ||
| actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ | ||
| actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ | ||
| actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ | ||
| actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ | ||
| actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ | ||
| actor_rollout_ref.actor.clip_ratio_c=10.0 \ | ||
| actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ | ||
| actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ | ||
| actor_rollout_ref.actor.strategy="fsdp" \ | ||
| actor_rollout_ref.actor.optim.lr=1e-6 \ | ||
| actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ | ||
| actor_rollout_ref.actor.optim.weight_decay=0.1 \ | ||
| actor_rollout_ref.actor.optim.warmup_style=constant \ | ||
| actor_rollout_ref.actor.optim.min_lr_ratio=0. \ | ||
| actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ | ||
| actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ | ||
| actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ | ||
| actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ | ||
| actor_rollout_ref.actor.entropy_coeff=0 \ | ||
| actor_rollout_ref.actor.grad_clip=1.0 \ | ||
| actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ | ||
| actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ | ||
| actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ | ||
| actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ | ||
| actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ | ||
| actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ | ||
| actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ | ||
| actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ | ||
| actor_rollout_ref.rollout.name=vllm \ | ||
| actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ | ||
| actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ | ||
| actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ | ||
| actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ | ||
| actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ | ||
| actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ | ||
| actor_rollout_ref.rollout.enable_chunked_prefill=True \ | ||
| actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ | ||
| actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ | ||
| actor_rollout_ref.rollout.temperature=${temperature} \ | ||
| actor_rollout_ref.rollout.top_p=${top_p} \ | ||
| actor_rollout_ref.rollout.top_k=${top_k} \ | ||
| actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ | ||
| actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ | ||
| actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ | ||
| actor_rollout_ref.rollout.val_kwargs.n=1 \ | ||
| actor_rollout_ref.rollout.val_kwargs.do_sample=True \ | ||
| actor_rollout_ref.model.path=$BASE_MODEL \ | ||
| actor_rollout_ref.model.use_remove_padding=True \ | ||
| actor_rollout_ref.rollout.multi_turn.enable=False \ | ||
| actor_rollout_ref.rollout.mode="sync" \ | ||
| +actor_rollout_ref.model.override_config.attention_dropout=0. \ | ||
| +actor_rollout_ref.model.override_config.embd_pdrop=0. \ | ||
| +actor_rollout_ref.model.override_config.resid_pdrop=0. \ | ||
| actor_rollout_ref.model.enable_gradient_checkpointing=True \ | ||
| reward_model.reward_manager=async_multi_process \ | ||
| reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ | ||
| reward_model.overlong_buffer.len=${overlong_buffer_len} \ | ||
| reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ | ||
| trainer.logger=['console','wandb'] \ | ||
| trainer.project_name=${WANDB_PROJECT} \ | ||
| trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ | ||
| trainer.val_before_train=False \ | ||
| trainer.n_gpus_per_node=8 \ | ||
| trainer.nnodes=$worker_num \ | ||
| trainer.save_freq=50 \ | ||
| trainer.test_freq=100 \ | ||
| trainer.total_epochs=1 \ | ||
| trainer.log_val_generations=50 \ | ||
| trainer.resume_mode=auto \ | ||
| trainer.max_actor_ckpt_to_keep=2 \ | ||
| actor_rollout_ref.actor.checkpoint.async_save=False | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,139 @@ | ||
| ## added by reasoning360 | ||
jb3618columbia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| import json | ||
| from verl.utils.reward_score.toolcall import compute_score_v0 | ||
|
|
||
|
|
||
| class TestComputeScoreV0: | ||
| """Unit tests for compute_score_v0 function""" | ||
|
|
||
| def test_correct_solution_with_thinking(self): | ||
| """Test: Correct solution with thinking tags should return 1""" | ||
| ground_truth = json.dumps( | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"}}] | ||
| ) | ||
| solution_str = """<|im_start|>assistant | ||
| <think> | ||
| I need to calculate 2+2 which equals 4. | ||
| </think> | ||
| <tool_call> | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"}}] | ||
| </tool_call>""" | ||
|
|
||
| score = compute_score_v0(solution_str, ground_truth) | ||
| assert score == 1 | ||
|
|
||
| def test_correct_solution_without_thinking(self): | ||
| """Test: Correct solution without thinking tags should return 0""" | ||
| ground_truth = json.dumps( | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"}}] | ||
| ) | ||
| solution_str = """<|im_start|>assistant | ||
| <tool_call> | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"}}] | ||
| </tool_call>""" | ||
|
|
||
| score = compute_score_v0(solution_str, ground_truth) | ||
| assert score == 0 | ||
|
|
||
| def test_malformed_json_in_tool_call(self): | ||
| """Test: Malformed JSON in tool call should return 0""" | ||
| ground_truth = json.dumps( | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"}}] | ||
| ) | ||
| solution_str = """<|im_start|>assistant | ||
| <think> | ||
| I need to calculate something. | ||
| </think> | ||
| <tool_call> | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"} | ||
| </tool_call>""" | ||
|
|
||
| score = compute_score_v0(solution_str, ground_truth) | ||
| assert score == 0 | ||
|
|
||
|
|
||
| def test_wrong_tool_name(self): | ||
| """Test: Wrong tool name should return 0""" | ||
| ground_truth = json.dumps( | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"}}] | ||
| ) | ||
| solution_str = """<|im_start|>assistant | ||
| <think> | ||
| I need to calculate 2+2. | ||
| </think> | ||
| <tool_call> | ||
| [{"name": "wrong_tool", "arguments": {"expression": "2+2"}}] | ||
| </tool_call>""" | ||
|
|
||
| score = compute_score_v0(solution_str, ground_truth) | ||
| assert score == 0 | ||
|
|
||
| def test_wrong_arguments(self): | ||
| """Test: Correct tool name but wrong arguments should return 0""" | ||
| ground_truth = json.dumps( | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"}}] | ||
| ) | ||
| solution_str = """<|im_start|>assistant | ||
| <think> | ||
| I need to calculate 2+3. | ||
| </think> | ||
| <tool_call> | ||
| [{"name": "calculator", "arguments": {"expression": "2+3"}}] | ||
| </tool_call>""" | ||
|
|
||
| score = compute_score_v0(solution_str, ground_truth) | ||
| assert score == 0 | ||
|
|
||
| def test_missing_required_fields(self): | ||
| """Test: Missing required fields (name or arguments) should return 0""" | ||
| ground_truth = json.dumps( | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"}}] | ||
| ) | ||
| solution_str = """<|im_start|>assistant | ||
| <think> | ||
| I need to calculate 2+2. | ||
| </think> | ||
| <tool_call> | ||
| [{"name": "calculator"}] | ||
| </tool_call>""" | ||
|
|
||
| score = compute_score_v0(solution_str, ground_truth) | ||
| assert score == 0 | ||
|
|
||
| def test_multiple_tool_calls(self): | ||
| """Test: Multiple tool calls that match ground truth""" | ||
| ground_truth = json.dumps( | ||
| [ | ||
| {"name": "calculator", "arguments": {"expression": "2+2"}}, | ||
| {"name": "calculator", "arguments": {"expression": "3*4"}}, | ||
| ] | ||
| ) | ||
| solution_str = """<|im_start|>assistant | ||
| <think> | ||
| I need to perform two calculations. | ||
| </think> | ||
| <tool_call> | ||
| [ | ||
| {"name": "calculator", "arguments": {"expression": "2+2"}}, | ||
| {"name": "calculator", "arguments": {"expression": "3*4"}} | ||
| ] | ||
| </tool_call>""" | ||
|
|
||
| score = compute_score_v0(solution_str, ground_truth) | ||
| assert score == 1 | ||
|
|
||
| def test_no_tool_call_tag(self): | ||
| """Test: Missing tool_call tags should return 0 (no extraction)""" | ||
| ground_truth = json.dumps( | ||
| [{"name": "calculator", "arguments": {"expression": "2+2"}}] | ||
| ) | ||
| solution_str = """<|im_start|>assistant | ||
| <think> | ||
| I need to calculate 2+2. | ||
| </think> | ||
| The result is 4.""" | ||
|
|
||
| score = compute_score_v0(solution_str, ground_truth) | ||
| assert score == 0 | ||
|
|
||
jb3618columbia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please remove this file from the commit