diff --git a/examples/deepswe/train_deepswe_nb.py b/examples/deepswe/train_deepswe_nb.py
index e847910fd..a36e48c78 100644
--- a/examples/deepswe/train_deepswe_nb.py
+++ b/examples/deepswe/train_deepswe_nb.py
@@ -185,6 +185,17 @@
     "--loss_agg_mode", type=str, default="sequence-mean-token-scale"
 )
 parser.add_argument("--advantage_estimator", type=str, default="rloo")
+parser.add_argument(
+    "--degenerate_group_masking",
+    # NOTE: `type=bool` would turn any non-empty string (including "False")
+    # into True, so parse the textual value explicitly instead.
+    type=lambda s: s.strip().lower() in ("1", "true", "t", "yes", "y"),
+    default=False,
+    help=(
+        "Whether to mask out groups whose advantages are all zero. "
+        "Default is False to align with rLLM DeepSWE."
+    ),
+)
 
 # Other
 
@@ -455,6 +466,7 @@
 )
 LOSS_AGG_MODE = args.loss_agg_mode
 ADVANTAGE_ESTIMATOR = args.advantage_estimator
+DEGENERATE_GROUP_MASKING = args.degenerate_group_masking
 
 # %%
 
@@ -778,6 +790,7 @@ def transform(entry):
     "filter_statuses": FILTER_STATUSES,
     "loss_agg_mode": LOSS_AGG_MODE,
     "advantage_estimator": ADVANTAGE_ESTIMATOR,
+    "degenerate_group_masking": DEGENERATE_GROUP_MASKING,
 }
 
 grpo_config = agentic_grpo_learner.GRPOConfig(**config_kwargs)
@@ -856,6 +869,7 @@ def mixed_type_batch_fn(elements):
     "filter_statuses": (
         [s.name for s in FILTER_STATUSES] if FILTER_STATUSES else None
     ),
+    "degenerate_group_masking": DEGENERATE_GROUP_MASKING,
     # Mesh topology
     "num_devices": len(devices),
     "rollout_mesh_fsdp": rollout_fsdp,