refactor ppo callback to move its logic to single controller (part 1) #115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Commits (110)
All commits are by bowenyang008:

- runs but not limited master_port (b0b57d9)
- hack local rank (9e70dbb)
- composer launch works; torchrun somehow only allows use the same init… (bc72f3e)
- clean up (461c8c4)
- timeout to 30s (e7803bb)
- update script (feb62d1)
- None evn (237373c)
- change internval (4193ee3)
- try break and nodes instead (fd63dcb)
- condition (5c3f61d)
- rm resources (ff227c4)
- sleep it (172c676)
- use dist barrier to block (f417e5b)
- context manager (b29ba44)
- try to not release port (9a30028)
- use ray actor (12fbcef)
- another way to get ip address (4ad2e65)
- ray init (8df20ec)
- ray remote (76c73a3)
- barrier (8f51968)
- half subprocess (a00369c)
- try w/o port (a777c54)
- claude fix; questionable (dd97cf2)
- do not raise (09c92ed)
- revert back to subprocess (cef2b9e)
- mix init (9d93001)
- manual stop (9dc7a74)
- two gpus runs (c3db295)
- tensor parallel size 8 (824f086)
- update world size (443d639)
- try all gpus for train again (fde2960)
- use old port assignment (8b625aa)
- get freeport again (52bb8aa)
- change order again (934ef7a)
- rank 0 again (04217ab)
- half trainers (f68c703)
- import; new method (1fda4e3)
- no import (6daa40c)
- import seems the killer (42ab521)
- import vllm (c86b2df)
- change rend time to 30s (160d861)
- def method (8eafdc6)
- add back vllm init (0dc9f95)
- rm dup code (d3c1a7c)
- comment out worker node (8eb0116)
- running generations (231674d)
- weight update donw (d296b4e)
- not sync weight (fe582a6)
- revert back (f44f469)
- call master only (01862b9)
- update pyproject for pyright filtering (b7ebe7d)
- temp single controller (1916fcd)
- separate out roundtrip code (d8d4816)
- relocate to use test (9d9850b)
- ref built (8352a08)
- build ppo trainer (4f093a4)
- weight update name mismatch; inference and prompts/gen exchange and t… (88758c9)
- update ref model build (1871b12)
- trains e2e! (d2eb8f4)
- clean up imports (567e37b)
- recover assert (eb5b5c6)
- Merge remote-tracking branch 'origin/main' into boweny/playground (f06dcfa)
- add test files (133dac2)
- Merge branch 'main' into boweny/single-controller-composer (da377bd)
- clean up dir (3240715)
- refactor base actor (861768b)
- rel method (d0b5a44)
- callback fails (ec7e550)
- revive ppo training again (f7952b4)
- clean up (95b96ef)
- first round code reduction done (dee910b)
- another rm (1e0b07a)
- train again; trim down (b23b80e)
- recover weight update (2f8c946)
- share method (2966ccf)
- Revert "share method" (fb98225)
- rm files (be08a75)
- docs (7dac060)
- doc (d4a76e1)
- actor group (24d2508)
- no more explicit master addr (e37d2ec)
- all class ready (7feb36c)
- run cmd (c25bebc)
- use device direclty (7195e4e)
- clean up to last method (f28b8c8)
- clean up (d180533)
- mv vllm engines out (5c18c99)
- yeah works (369e58e)
- relocate file and pytest works (5a69a15)
- rm file (5e2ebc2)
- format (9e21f3f)
- Merge remote-tracking branch 'origin/main' into boweny/single-control… (470bef9)
- format (44b6609)
- type ignore (0662810)
- todo (ab4faad)
- different type fix (7ff5867)
- todos (092da43)
- doc fix (6c5f216)
- format (00f32f3)
- change gpu test to 2 (7094033)
- format (59305a6)
- revert 4 gpu for now, regression does not like this test (47d9b5f)
- revert 2 gpu for now (9cb6e98)
- todo (0b5d911)
- use 3.11 and update doc formatter (e268626)
- revert change (e9e9b3c)
- try diff (6a513e4)
- yapf again (81fce8c)
- todo (eb889f1)
- revert
compose_rl/algorithms/online/single_controller_callback.py (58 additions, 0 deletions)
```python
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""Online On-Policy RL callback."""

from __future__ import annotations

import logging
from typing import Union

from composer.core import State
from composer.loggers import Logger
from composer.trainer.trainer import _get_initial_device_train_microbatch_size
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# Import the base class
from compose_rl.algorithms.online.callback import OnPolicyCallback
from compose_rl.algorithms.online.model import (
    ComposerHFPolicyLM,
    ComposerMPTPolicyLM,
)

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
Policy = Union[ComposerHFPolicyLM, ComposerMPTPolicyLM]

__all__ = ['SingleControllerOnPolicyCallback']

log = logging.getLogger(__name__)


class SingleControllerOnPolicyCallback(OnPolicyCallback):
    """Callback for managing on-policy training in an RLHF loop.

    Ideally, all of the overridden methods below would be implemented in the
    trainer actor instead of the callback; they are kept here for now to avoid
    a drastic refactor of the PPO callback code.
    """

    def iteration_start(self, state: State, logger: Logger):
        del logger  # unused

        self._get_reward(self.batch_rollouts)  # type: ignore

        # Reset and initialize the state's train dataloader
        log.warning(
            'trainer._train_data_spec should be updated whenever the dataloader is updated',
        )
        # Train dataloader
        state.set_dataloader(self.buffer, 'ep')
        state.train_dataloader = state.dataloader
        state.device_train_microbatch_size = _get_initial_device_train_microbatch_size(
            state.device_train_microbatch_size,
            state.auto_microbatching,
            state.train_dataloader,
        )

        # Update IFT KL
        self._update_ift_kl()
```
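The callback above leans on Composer and compose_rl internals, but its core single-controller move at each iteration boundary (score the latest rollouts, then point the trainer's dataloader at the refreshed buffer) can be sketched with plain-Python stand-ins. Everything below (`MiniState`, `MiniOnPolicyCallback`, the stub reward) is hypothetical scaffolding for illustration, not compose_rl or Composer API:

```python
# Minimal sketch of the iteration_start pattern; all classes here are
# illustrative stand-ins, not compose_rl or Composer APIs.

class MiniState:
    """Stands in for composer.core.State: tracks the active dataloader."""

    def __init__(self):
        self.dataloader = None
        self.train_dataloader = None
        self.label = None

    def set_dataloader(self, dataloader, label):
        self.dataloader = dataloader
        self.label = label


class MiniOnPolicyCallback:
    """Holds rollouts; subclasses decide when to refresh the dataloader."""

    def __init__(self):
        self.batch_rollouts = [{'prompt': 'p1'}, {'prompt': 'p2'}]
        self.buffer = []

    def _get_reward(self, rollouts):
        # Stub reward: score each rollout and push it into the buffer.
        for rollout in rollouts:
            self.buffer.append({**rollout, 'reward': 1.0})


class MiniSingleControllerCallback(MiniOnPolicyCallback):
    """Mirrors the shape of SingleControllerOnPolicyCallback.iteration_start."""

    def iteration_start(self, state):
        self._get_reward(self.batch_rollouts)
        # Swap the trainer onto the freshly scored buffer, as the real
        # callback does with state.set_dataloader(self.buffer, 'ep').
        state.set_dataloader(self.buffer, 'ep')
        state.train_dataloader = state.dataloader


state = MiniState()
cb = MiniSingleControllerCallback()
cb.iteration_start(state)
print(len(state.train_dataloader))  # -> 2
```

The design point this mirrors is that the callback, not the trainer, owns the rollout-to-dataloader handoff, which is exactly what the PR aims to eventually migrate onto the trainer actor.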
New file (101 additions, 0 deletions)
```python
# Copyright 2024 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

import os
from datetime import timedelta
from typing import Optional

import ray
import torch.distributed as dist

from compose_rl.algorithms.online.generation_utils import init_process_group
from compose_rl.utils.ray_utils import (
    get_free_port,
    get_node_ip,
    is_cuda_visible_devices_set,
)


class BaseDistributedGPUActor:

    def __init__(
        self,
        rank: int,
        world_size: int,
        master_addr: Optional[str] = None,
        master_port: Optional[int] = None,
    ):
        """Initialize the distributed GPU actor for Ray.

        Args:
            rank: The rank of this process in the distributed group.
            world_size: Total number of processes in the distributed group.
            master_addr: Master node address. If None, allocated dynamically on rank 0.
            master_port: Master node port. If None, allocated dynamically on rank 0.
        """
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port

        # Set up basic environment variables
        # TODO: may need to handle LOCAL_WORLD_SIZE as used in callback.py
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['RANK'] = str(rank)

        # Set LOCAL_RANK based on Ray GPU allocation
        os.environ['LOCAL_RANK'] = (
            '0' if is_cuda_visible_devices_set() else str(ray.get_gpu_ids()[0])
        )

        # If this is rank 0 and no master_addr/master_port provided, allocate them
        if rank == 0 and (master_addr is None or master_port is None):
            self._allocate_master_address()

        os.environ['MASTER_ADDR'] = self.master_addr  # type: ignore
        os.environ['MASTER_PORT'] = str(self.master_port)  # type: ignore

        self.model = None
        self.model_update_group = None

    def _allocate_master_address(self):
        """Allocate master address and port for rank 0."""
        if self.master_addr is None:
            # Get the local IP address
            self.master_addr = get_node_ip()

        if self.master_port is None:
            # Allocate a free port
            self.master_port = get_free_port()

    def get_master_address(self) -> tuple[Optional[str], Optional[int]]:
        """Return the master address and port as a tuple."""
        return (self.master_addr, self.master_port)

    def get_free_port(self):
        return get_free_port()

    def init_train_process_group(self):
        """Initialize the distributed process group."""
        dist.init_process_group(timeout=timedelta(seconds=30))

    def add_process_group(
        self,
        backend: str,
        master_addr: str,
        master_port: int,
        world_size: int,
        rank: int,
        group_name: str,
    ):
        """Initialize the process group on trainer rank 0 and vLLM engines."""
        # NOTE: vLLM seems to have a safer implementation of init_process_group:
        # https://github.com/vllm-project/vllm/blob/v0.9.1/examples/offline_inference/rlhf.py#L105
        # we should look into using that instead
        self.model_update_group = init_process_group(
            backend=backend,
            init_method=f'tcp://{master_addr}:{master_port}',
            world_size=world_size,
            rank=rank,
            group_name=group_name,
        )
```
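The rank-0 path in `_allocate_master_address` rests on two small utilities, `get_node_ip` and `get_free_port`, imported from `compose_rl.utils.ray_utils`. A common stdlib way to implement them (an assumption here; the actual `ray_utils` implementations may differ) is to bind a socket to port 0 and let the OS pick, and to read the local address off a UDP socket:

```python
import socket

def get_free_port() -> int:
    """Ask the OS for an unused TCP port by binding to port 0 (sketch)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]

def get_node_ip() -> str:
    """Find this node's outward-facing IP via a UDP socket (no traffic sent)."""
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        try:
            s.connect(('8.8.8.8', 80))  # any routable address works
            return s.getsockname()[0]
        except OSError:
            return '127.0.0.1'

# Rank 0 would combine these into the init_method string that
# add_process_group passes to init_process_group:
addr, port = get_node_ip(), get_free_port()
init_method = f'tcp://{addr}:{port}'
```

Note the caveat this pattern carries (and which several commits in this PR wrestled with): the port is only guaranteed free at the moment of the probe, so another process can grab it before `init_process_group` binds it.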