diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 004b17e4..0361f651 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -284,6 +284,10 @@ def build_ppo_trainer(self): load_path=self.ref_path, ) + def close_trainer(self): + self.ppo_trainer.close() + + def add_rollouts(self, current_rank_rollouts: dict[str, Any]): """Adds the current rank's rollouts to the callback.""" for k, v in current_rank_rollouts.items(): @@ -474,6 +478,7 @@ def get_next_iter_rollouts(self): processed_sequences = torch.cat([all_prompts, padded_responses], dim=-1) iter_data['sequences'] = processed_sequences + return iter_data @@ -652,6 +657,9 @@ def train(self): # Populate the train actor group with the rollouts and then train self.train_actor.add_latest_rollouts_from_buffer(self.experience_buffer) self.train_actor.collective_methods.train_1_iter() + + self.train_actor.collective_methods.close_trainer() + def _run_single_controller_ppo(