-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
215 lines (189 loc) · 7.82 KB
/
main.py
File metadata and controls
215 lines (189 loc) · 7.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""Command-line entry point for the RedDebate framework.
This module parses CLI arguments, prints the resolved configuration, and
dispatches to :func:`redDebate.run.run_debate`. It supports configuring
debater models, optional roles (devil/angel, Socratic questioner, evaluator,
feedback generator, self-critic), the dataset(s) to iterate over, textual memory
backends (array-based long-term memory or a Pinecone vector store), and
PEFT/LoRA fine-tuning as a form of continuous long-term memory.
"""
from redDebate.run import run_debate
import argparse
def parse_debate_args(argv=None):
    """Parse command-line arguments for a debate run and print the resolved configuration.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            behaviour of reading the real command line; passing an explicit
            list makes the parser usable from tests and other Python code.

    Returns:
        argparse.Namespace: parsed arguments. The ``--models`` and
        ``--datasets`` entries are lists of ``<type>:<name>[:use_chat]`` and
        ``<dataset_name>:<dataset_path>`` strings respectively (see argparse
        ``--help`` for the full per-argument descriptions).

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``, and
            when ``--peft_memory`` is given without ``--peft_directory``.
    """
    # Shared description of the model-spec string accepted by every model option.
    # The separator before <use_chat> is a ':' (see the examples), and it matches
    # the "<type>:<name>[:use_chat]" form documented in the docstring above.
    model_spec = (
        "in the format <type>:<model_name_or_path>[:<use_chat>]. For example, "
        "'openai:gpt-4o-mini:true' or 'huggingface:/path/to/local/model:false'."
    )

    parser = argparse.ArgumentParser(description="Facilitate interaction between different types of LLMs.")
    parser.add_argument(
        "--models",
        nargs='+',
        type=str,
        required=False,
        default=[],
        help=f"List of debating models to use, {model_spec}",
    )
    parser.add_argument(
        "--angel_model",
        type=str,
        required=False,
        default=None,
        help=f"Specify the angel model {model_spec}",
    )
    parser.add_argument(
        "--devil_model",
        type=str,
        required=False,
        default=None,
        help=f"Specify the devil model {model_spec}",
    )
    parser.add_argument(
        "--evaluator",
        type=str,
        required=False,
        default=None,
        help=f"Specify the evaluator model {model_spec} You can use openai:moderation to use OpenAI's moderation API.",
    )
    parser.add_argument(
        "--feedback_generator",
        type=str,
        required=False,
        default=None,
        help=f"Specify the feedback generator model {model_spec}",
    )
    parser.add_argument(
        "--questioner_model",
        type=str,
        required=False,
        default=None,
        help=f"Specify the questioner model {model_spec}",
    )
    parser.add_argument(
        "--self_critique_model",
        type=str,
        required=False,
        default=None,
        help=f"Specify the self-critique model {model_spec}",
    )
    parser.add_argument(
        "--datasets",
        nargs='+',
        type=str,
        required=True,
        help="List of datasets to use for each model interaction in the format <dataset_name>:<dataset_path>. For example, 'harmbench:/path/to/local/dataset'. Choose dataset name from: 'harmbench', 'cosafe', 'toxicchat', 'hhrlhf', or 'saferdialogues'.",
    )
    parser.add_argument(
        "--debate_rounds",
        type=int,
        default=2,
        help="Number of debate rounds between the models. Default is 2.",
    )
    parser.add_argument(
        "--max_total_debates",
        type=int,
        default=None,
        help="Maximum number of debates to run. If not set, there is no constraint.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="debate.log",
        help="File to log the debate. Default is 'debate.log'.",
    )
    parser.add_argument(
        "--textual_memory_index",
        type=str,
        default=None,
        help="Name of the index in vector database to store long-term memory. Default is None which initiate the array-based memory instead.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=None,
        help="Directory to save and load checkpoints. Default is None.",
    )
    parser.add_argument(
        "--llamaguard_cuda_device",
        type=int,
        default=0,
        help="CUDA device index for LlamaGuard. Default is 0.",
    )
    parser.add_argument(
        "--peft_memory",
        action="store_true",
        default=False,
        help=(
            "Enable PEFT/LoRA fine-tuning as a form of long-term memory. "
            "The debating HuggingFace models are fine-tuned on accumulated feedback every "
            "--train_steps debates. Long-term memory continues to work independently if a "
            "--feedback_generator is provided. Default is disabled."
        ),
    )
    parser.add_argument(
        "--peft_directory",
        type=str,
        default=None,
        help=(
            "Sub-directory (relative to HF_HUB_CACHE) where LoRA checkpoints are saved. "
            "Required when --peft_memory is set. To run inference with a previously trained LoRA "
            "model without further training, simply pass its merged checkpoint path in --models."
        ),
    )
    parser.add_argument(
        "--train_steps",
        type=int,
        default=10,
        help="Number of feedback samples to accumulate before triggering PEFT training. Default is 10.",
    )
    parser.add_argument(
        "--human_in_the_loop",
        type=int,
        default=0,
        help="Number of human participants to include in each debate round. Each human is prompted for input via stdin. Default is 0 (no humans).",
    )

    # Parse arguments (argv=None means argparse reads sys.argv[1:]).
    args = parser.parse_args(argv)

    # --peft_directory is documented as required together with --peft_memory;
    # enforce that at parse time so the user gets a usage error immediately.
    if args.peft_memory and not args.peft_directory:
        parser.error("--peft_directory is required when --peft_memory is set.")

    # (label, value) pairs echoed to stdout so each run logs its effective configuration.
    config_summary = (
        ("Debating Models", args.models),
        ("Angel Model", args.angel_model),
        ("Devil Model", args.devil_model),
        ("Evaluator Model", args.evaluator),
        ("Feedback Generator Model", args.feedback_generator),
        ("Questioner Model", args.questioner_model),
        ("Self Critique Model", args.self_critique_model),
        ("Datasets", args.datasets),
        ("Debate Rounds", args.debate_rounds),
        ("Max Total Debates", args.max_total_debates),
        ("Output File", args.output_file),
        ("Textual Memory Index", args.textual_memory_index),
        ("Checkpoint Directory", args.checkpoint_dir),
        ("LlamaGuard CUDA Device", args.llamaguard_cuda_device),
        ("PEFT Memory", args.peft_memory),
        ("PEFT Directory", args.peft_directory),
        ("PEFT Train Steps", args.train_steps),
        ("Humans in the Loop", args.human_in_the_loop),
    )
    print("Configurations:")
    for label, value in config_summary:
        print(f"{label}: {value}")
    return args
if __name__ == "__main__":
    # Script entry: read the CLI, then forward every option to the debate
    # runner by keyword so the CLI flag -> run_debate parameter mapping is explicit.
    cli = parse_debate_args()
    debate_kwargs = dict(
        debater_models=cli.models,
        angel_model=cli.angel_model,
        devil_model=cli.devil_model,
        evaluator_model=cli.evaluator,
        feedback_generator=cli.feedback_generator,
        questioner_model=cli.questioner_model,
        self_critique_model=cli.self_critique_model,
        datasets=cli.datasets,
        debate_rounds=cli.debate_rounds,
        max_total_debates=cli.max_total_debates,
        output_file=cli.output_file,
        textual_memory_index=cli.textual_memory_index,
        llamaguard_cuda_device=cli.llamaguard_cuda_device,
        checkpoint_dir=cli.checkpoint_dir,
        peft_memory=cli.peft_memory,
        peft_directory=cli.peft_directory,
        train_steps=cli.train_steps,
        human_in_the_loop=cli.human_in_the_loop,
    )
    run_debate(**debate_kwargs)