Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions topicgpt_python/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,18 @@ def assignment_batch(
return responses, prompted_docs


def assign_topics(api, model, data, prompt_file, out_file, topic_file, verbose):
def assign_topics(
api,
model,
data,
prompt_file,
out_file,
topic_file,
verbose,
max_tokens=1000,
temperature=0.0,
top_p=1.0
):
"""
Assign topics to a list of documents

Expand All @@ -197,9 +208,11 @@ def assign_topics(api, model, data, prompt_file, out_file, topic_file, verbose):
- out_file (str): Output file
- topic_file (str): File to write topics to
- verbose (bool): Whether to print out results
- max_tokens (int): Maximum number of tokens to generate (default: 1000)
- temperature (float): Sampling temperature (default: 0.0)
- top_p (float): Top-p sampling threshold (default: 1.0)
"""
api_client = APIClient(api=api, model=model)
max_tokens, temperature, top_p = 1000, 0.0, 1.0

if verbose:
print("-------------------")
Expand Down Expand Up @@ -301,6 +314,15 @@ def assign_topics(api, model, data, prompt_file, out_file, topic_file, verbose):
parser.add_argument(
"--verbose", type=bool, default=False, help="whether to print out results"
)
parser.add_argument(
"--max_tokens", type=int, default=1000, help="Maximum number of tokens to generate"
)
parser.add_argument(
"--temperature", type=float, default=0.0, help="Sampling temperature"
)
parser.add_argument(
"--top_p", type=float, default=1.0, help="Top-p sampling threshold"
)

args = parser.parse_args()
assign_topics(
Expand All @@ -311,4 +333,7 @@ def assign_topics(api, model, data, prompt_file, out_file, topic_file, verbose):
args.out_file,
args.topic_file,
args.verbose,
args.max_tokens,
args.temperature,
args.top_p,
)
56 changes: 41 additions & 15 deletions topicgpt_python/correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@ def topic_parser(root_topics, df, verbose=False):
"""
error, hallucinated = [], []
valid_topics = set(root_topics.get_root_descendants_name())
topic_pattern = re.compile(r"\[\d\] [\w\s\-'\&]+")
strip_pattern = re.compile(r"^[^a-zA-Z]+|[^a-zA-Z]+$")
topic_pattern = re.compile(r"^\[\d+\] ([\w\s\-'\&]+):", re.MULTILINE)

for i, response in enumerate(df.responses.tolist()):
extracted_topics = [
re.sub(strip_pattern, "", topic)
topic.strip()
for topic in re.findall(topic_pattern, response)
]

Expand Down Expand Up @@ -72,13 +71,14 @@ def correct(

for i in tqdm(reprompt_idx, desc="Correcting topics"):
doc = df.at[i, "prompted_docs"]
current_topics = all_topics
if (
api_client.estimate_token_count(doc + correction_prompt + all_topics)
api_client.estimate_token_count(doc + correction_prompt + current_topics)
> context_len
):
topic_embeddings = {
topic: sbert.encode(topic, convert_to_tensor=True)
for topic in all_topics.split("\n")
for topic in current_topics.split("\n")
}
doc_embedding = sbert.encode(doc, convert_to_tensor=True)
top_topics = sorted(
Expand All @@ -93,18 +93,18 @@ def correct(
and len(top_topics) > 50
):
top_topics.pop()
all_topics = "\n".join(top_topics)
current_topics = "\n".join(top_topics)

max_doc_len = context_len - api_client.estimate_token_count(
correction_prompt + all_topics
correction_prompt + current_topics
)
if api_client.estimate_token_count(doc) > max_doc_len:
doc = api_client.truncate(doc, max_doc_len)

try:
msg = f"Previously, this document was assigned to: {df.at[i, 'responses']}. Please reassign it to an existing topic in the hierarchy."
prompt = correction_prompt.format(
Document=doc, tree=all_topics, Message=msg
Document=doc, tree=current_topics, Message=msg
)
result = api_client.iterative_prompt(
prompt, max_tokens=max_tokens, temperature=temperature, top_p=top_p
Expand Down Expand Up @@ -138,13 +138,14 @@ def correct_batch(

for i in tqdm(reprompt_idx, desc="Correcting topics"):
doc = df.at[i, "prompted_docs"]
current_topics = all_topics
if (
api_client.estimate_token_count(doc + correction_prompt + all_topics)
api_client.estimate_token_count(doc + correction_prompt + current_topics)
> context_len
):
topic_embeddings = {
topic: sbert.encode(topic, convert_to_tensor=True)
for topic in all_topics.split("\n")
for topic in current_topics.split("\n")
}
doc_embedding = sbert.encode(doc, convert_to_tensor=True)
top_topics = sorted(
Expand All @@ -159,15 +160,15 @@ def correct_batch(
and len(top_topics) > 50
):
top_topics.pop()
all_topics = "\n".join(top_topics)
current_topics = "\n".join(top_topics)

max_doc_len = context_len - api_client.estimate_token_count(
correction_prompt + all_topics
correction_prompt + current_topics
)
if api_client.estimate_token_count(doc) > max_doc_len:
doc = api_client.truncate(doc, max_doc_len)
msg = f"Previously, this document was assigned to: {df.at[i, 'responses']}. Please reassign it to an existing topic in the hierarchy."
prompt = correction_prompt.format(Document=doc, tree=all_topics, Message=msg)
prompt = correction_prompt.format(Document=doc, tree=current_topics, Message=msg)
prompts.append(prompt)

responses = api_client.batch_prompt(
Expand All @@ -183,7 +184,16 @@ def correct_batch(


def correct_topics(
api, model, data_path, prompt_path, topic_path, output_path, verbose=False
api,
model,
data_path,
prompt_path,
topic_path,
output_path,
verbose=False,
max_tokens=1000,
temperature=0.6,
top_p=0.9
):
"""
Main function to parse, correct, and save topic assignments.
Expand All @@ -196,9 +206,11 @@ def correct_topics(
- topic_path: Path to topic file
- output_path: Path to save corrected output
- verbose: Print verbose output
- max_tokens (int): Maximum number of tokens to generate (default: 1000)
- temperature (float): Sampling temperature (default: 0.6)
- top_p (float): Top-p sampling threshold (default: 0.9)
"""
api_client = APIClient(api=api, model=model)
max_tokens, temperature, top_p = 1000, 0.6, 0.9
context_len = (
128000
if model not in ["gpt-3.5-turbo", "gpt-4"]
Expand Down Expand Up @@ -233,6 +245,7 @@ def correct_topics(
reprompt_idx,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
verbose=verbose,
)
else:
Expand All @@ -244,6 +257,7 @@ def correct_topics(
context_len,
reprompt_idx,
verbose=verbose,
max_tokens=max_tokens,
)
df.to_json(output_path, lines=True, orient="records")
error, hallucinated = topic_parser(topics_root, df, verbose)
Expand Down Expand Up @@ -294,6 +308,15 @@ def correct_topics(
help="Path to save corrected output",
)
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
parser.add_argument(
"--max_tokens", type=int, default=1000, help="Maximum number of tokens to generate"
)
parser.add_argument(
"--temperature", type=float, default=0.6, help="Sampling temperature"
)
parser.add_argument(
"--top_p", type=float, default=0.9, help="Top-p sampling threshold"
)
args = parser.parse_args()

correct_topics(
Expand All @@ -304,4 +327,7 @@ def correct_topics(
args.topic_path,
args.output_path,
args.verbose,
args.max_tokens,
args.temperature,
args.top_p,
)
32 changes: 29 additions & 3 deletions topicgpt_python/generation_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def generate_topics(
"""
responses = []
running_dups = 0
topic_format = regex.compile(r"^\[(\d+)\] ([\w\s]+):(.+)")
topic_format = regex.compile(r"^\[(\d+)\] ([\w\s\+_#-]+):(.+)")

for i, doc in enumerate(tqdm(docs)):
prompt = prompt_formatting(
Expand Down Expand Up @@ -151,7 +151,18 @@ def generate_topics(


def generate_topic_lvl1(
api, model, data, prompt_file, seed_file, out_file, topic_file, verbose
api,
model,
data,
prompt_file,
seed_file,
out_file,
topic_file,
verbose,
max_tokens=1000,
temperature=0.0,
top_p=1.0,
early_stop=1000
):
"""
Generate high-level topics
Expand All @@ -165,12 +176,14 @@ def generate_topic_lvl1(
- out_file (str): File to write results to
- topic_file (str): File to write topics to
- verbose (bool): Whether to print out results
- max_tokens (int): Maximum number of tokens to generate (default: 1000)
- temperature (float): Sampling temperature (default: 0.0)
- top_p (float): Top-p sampling threshold (default: 1.0)

Returns:
- topics_root (TopicTree): Root node of the topic tree
"""
api_client = APIClient(api=api, model=model)
max_tokens, temperature, top_p = 1000, 0.0, 1.0

if verbose:
print("-------------------")
Expand Down Expand Up @@ -211,6 +224,7 @@ def generate_topic_lvl1(
max_tokens,
top_p,
verbose,
early_stop=early_stop
)

# Save generated topics
Expand Down Expand Up @@ -273,6 +287,15 @@ def generate_topic_lvl1(
parser.add_argument(
"--verbose", type=bool, default=False, help="Whether to print out results"
)
parser.add_argument(
"--max_tokens", type=int, default=1000, help="Maximum number of tokens to generate"
)
parser.add_argument(
"--temperature", type=float, default=0.0, help="Sampling temperature"
)
parser.add_argument(
"--top_p", type=float, default=1.0, help="Top-p sampling threshold"
)
args = parser.parse_args()
generate_topic_lvl1(
args.api,
Expand All @@ -283,4 +306,7 @@ def generate_topic_lvl1(
args.out_file,
args.topic_file,
args.verbose,
args.max_tokens,
args.temperature,
args.top_p,
)
30 changes: 27 additions & 3 deletions topicgpt_python/generation_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def parse_document_topics(df, topics_list):

Returns: List of topics for each document
"""
pattern = regex.compile(r"^\[(\d+)\] ([\w\s]+):(.+)")
pattern = regex.compile(r"^\[(\d+)\] ([\w\s\+-_#]+):(.+)")
all_topics = []

responses = (
Expand Down Expand Up @@ -236,7 +236,17 @@ def generate_topics(


def generate_topic_lvl2(
api, model, seed_file, data, prompt_file, out_file, topic_file, verbose
api,
model,
seed_file,
data,
prompt_file,
out_file,
topic_file,
verbose,
max_tokens=1000,
temperature=0.0,
top_p=1.0
):
"""
Generate subtopics for each top-level topic.
Expand All @@ -250,11 +260,13 @@ def generate_topic_lvl2(
- out_file: Output result file
- topic_file: Output topics file
- verbose: Enable verbose output
- max_tokens (int): Maximum number of tokens to generate (default: 1000)
- temperature (float): Sampling temperature (default: 0.0)
- top_p (float): Top-p sampling threshold (default: 1.0)

Returns: Root node of the topic tree
"""
api_client = APIClient(api=api, model=model)
max_tokens, temperature, top_p = 1000, 0.0, 1.0

if verbose:
print("-------------------")
Expand Down Expand Up @@ -332,6 +344,15 @@ def generate_topic_lvl2(
help="Output topics file",
)
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
parser.add_argument(
"--max_tokens", type=int, default=1000, help="Maximum number of tokens to generate"
)
parser.add_argument(
"--temperature", type=float, default=0.0, help="Sampling temperature"
)
parser.add_argument(
"--top_p", type=float, default=1.0, help="Top-p sampling threshold"
)
args = parser.parse_args()

generate_topic_lvl2(
Expand All @@ -343,4 +364,7 @@ def generate_topic_lvl2(
args.out_file,
args.topic_file,
args.verbose,
args.max_tokens,
args.temperature,
args.top_p,
)
Loading