diff --git a/topicgpt_python/assignment.py b/topicgpt_python/assignment.py index 46373de..567d553 100644 --- a/topicgpt_python/assignment.py +++ b/topicgpt_python/assignment.py @@ -185,7 +185,18 @@ def assignment_batch( return responses, prompted_docs -def assign_topics(api, model, data, prompt_file, out_file, topic_file, verbose): +def assign_topics( + api, + model, + data, + prompt_file, + out_file, + topic_file, + verbose, + max_tokens=1000, + temperature=0.0, + top_p=1.0 +): """ Assign topics to a list of documents @@ -197,9 +208,11 @@ def assign_topics(api, model, data, prompt_file, out_file, topic_file, verbose): - out_file (str): Output file - topic_file (str): File to write topics to - verbose (bool): Whether to print out results + - max_tokens (int): Maximum number of tokens to generate (default: 1000) + - temperature (float): Sampling temperature (default: 0.0) + - top_p (float): Top-p sampling threshold (default: 1.0) """ api_client = APIClient(api=api, model=model) - max_tokens, temperature, top_p = 1000, 0.0, 1.0 if verbose: print("-------------------") @@ -301,6 +314,15 @@ def assign_topics(api, model, data, prompt_file, out_file, topic_file, verbose): parser.add_argument( "--verbose", type=bool, default=False, help="whether to print out results" ) + parser.add_argument( + "--max_tokens", type=int, default=1000, help="Maximum number of tokens to generate" + ) + parser.add_argument( + "--temperature", type=float, default=0.0, help="Sampling temperature" + ) + parser.add_argument( + "--top_p", type=float, default=1.0, help="Top-p sampling threshold" + ) args = parser.parse_args() assign_topics( @@ -311,4 +333,7 @@ def assign_topics(api, model, data, prompt_file, out_file, topic_file, verbose): args.out_file, args.topic_file, args.verbose, + args.max_tokens, + args.temperature, + args.top_p, ) diff --git a/topicgpt_python/correction.py b/topicgpt_python/correction.py index 00fd31b..657f74d 100644 --- a/topicgpt_python/correction.py +++ 
b/topicgpt_python/correction.py @@ -28,12 +28,11 @@ def topic_parser(root_topics, df, verbose=False): """ error, hallucinated = [], [] valid_topics = set(root_topics.get_root_descendants_name()) - topic_pattern = re.compile(r"\[\d\] [\w\s\-'\&]+") - strip_pattern = re.compile(r"^[^a-zA-Z]+|[^a-zA-Z]+$") + topic_pattern = re.compile(r"^\[\d+\] ([\w\s\-'\&]+):", re.MULTILINE) for i, response in enumerate(df.responses.tolist()): extracted_topics = [ - re.sub(strip_pattern, "", topic) + topic.strip() for topic in re.findall(topic_pattern, response) ] @@ -72,13 +71,14 @@ def correct( for i in tqdm(reprompt_idx, desc="Correcting topics"): doc = df.at[i, "prompted_docs"] + current_topics = all_topics if ( - api_client.estimate_token_count(doc + correction_prompt + all_topics) + api_client.estimate_token_count(doc + correction_prompt + current_topics) > context_len ): topic_embeddings = { topic: sbert.encode(topic, convert_to_tensor=True) - for topic in all_topics.split("\n") + for topic in current_topics.split("\n") } doc_embedding = sbert.encode(doc, convert_to_tensor=True) top_topics = sorted( @@ -93,10 +93,10 @@ def correct( and len(top_topics) > 50 ): top_topics.pop() - all_topics = "\n".join(top_topics) + current_topics = "\n".join(top_topics) max_doc_len = context_len - api_client.estimate_token_count( - correction_prompt + all_topics + correction_prompt + current_topics ) if api_client.estimate_token_count(doc) > max_doc_len: doc = api_client.truncate(doc, max_doc_len) @@ -104,7 +104,7 @@ def correct( try: msg = f"Previously, this document was assigned to: {df.at[i, 'responses']}. Please reassign it to an existing topic in the hierarchy." 
prompt = correction_prompt.format( - Document=doc, tree=all_topics, Message=msg + Document=doc, tree=current_topics, Message=msg ) result = api_client.iterative_prompt( prompt, max_tokens=max_tokens, temperature=temperature, top_p=top_p @@ -138,13 +138,14 @@ def correct_batch( for i in tqdm(reprompt_idx, desc="Correcting topics"): doc = df.at[i, "prompted_docs"] + current_topics = all_topics if ( - api_client.estimate_token_count(doc + correction_prompt + all_topics) + api_client.estimate_token_count(doc + correction_prompt + current_topics) > context_len ): topic_embeddings = { topic: sbert.encode(topic, convert_to_tensor=True) - for topic in all_topics.split("\n") + for topic in current_topics.split("\n") } doc_embedding = sbert.encode(doc, convert_to_tensor=True) top_topics = sorted( @@ -159,15 +160,15 @@ def correct_batch( and len(top_topics) > 50 ): top_topics.pop() - all_topics = "\n".join(top_topics) + current_topics = "\n".join(top_topics) max_doc_len = context_len - api_client.estimate_token_count( - correction_prompt + all_topics + correction_prompt + current_topics ) if api_client.estimate_token_count(doc) > max_doc_len: doc = api_client.truncate(doc, max_doc_len) msg = f"Previously, this document was assigned to: {df.at[i, 'responses']}. Please reassign it to an existing topic in the hierarchy." - prompt = correction_prompt.format(Document=doc, tree=all_topics, Message=msg) + prompt = correction_prompt.format(Document=doc, tree=current_topics, Message=msg) prompts.append(prompt) responses = api_client.batch_prompt( @@ -183,7 +184,16 @@ def correct_batch( def correct_topics( - api, model, data_path, prompt_path, topic_path, output_path, verbose=False + api, + model, + data_path, + prompt_path, + topic_path, + output_path, + verbose=False, + max_tokens=1000, + temperature=0.6, + top_p=0.9 ): """ Main function to parse, correct, and save topic assignments. 
@@ -196,9 +206,11 @@ def correct_topics( - topic_path: Path to topic file - output_path: Path to save corrected output - verbose: Print verbose output + - max_tokens (int): Maximum number of tokens to generate (default: 1000) + - temperature (float): Sampling temperature (default: 0.6) + - top_p (float): Top-p sampling threshold (default: 0.9) """ api_client = APIClient(api=api, model=model) - max_tokens, temperature, top_p = 1000, 0.6, 0.9 context_len = ( 128000 if model not in ["gpt-3.5-turbo", "gpt-4"] @@ -233,6 +245,7 @@ def correct_topics( reprompt_idx, temperature=temperature, top_p=top_p, + max_tokens=max_tokens, verbose=verbose, ) else: @@ -244,6 +257,7 @@ def correct_topics( context_len, reprompt_idx, verbose=verbose, + max_tokens=max_tokens, ) df.to_json(output_path, lines=True, orient="records") error, hallucinated = topic_parser(topics_root, df, verbose) @@ -294,6 +308,15 @@ def correct_topics( help="Path to save corrected output", ) parser.add_argument("--verbose", action="store_true", help="Enable verbose output") + parser.add_argument( + "--max_tokens", type=int, default=1000, help="Maximum number of tokens to generate" + ) + parser.add_argument( + "--temperature", type=float, default=0.6, help="Sampling temperature" + ) + parser.add_argument( + "--top_p", type=float, default=0.9, help="Top-p sampling threshold" + ) args = parser.parse_args() correct_topics( @@ -304,4 +327,7 @@ def correct_topics( args.topic_path, args.output_path, args.verbose, + args.max_tokens, + args.temperature, + args.top_p, ) diff --git a/topicgpt_python/generation_1.py b/topicgpt_python/generation_1.py index e81d8ba..3c4b174 100644 --- a/topicgpt_python/generation_1.py +++ b/topicgpt_python/generation_1.py @@ -93,7 +93,7 @@ def generate_topics( """ responses = [] running_dups = 0 - topic_format = regex.compile(r"^\[(\d+)\] ([\w\s]+):(.+)") + topic_format = regex.compile(r"^\[(\d+)\] ([\w\s\+_#-]+):(.+)") for i, doc in enumerate(tqdm(docs)): prompt = prompt_formatting( @@ 
-151,7 +151,18 @@ def generate_topics( def generate_topic_lvl1( - api, model, data, prompt_file, seed_file, out_file, topic_file, verbose + api, + model, + data, + prompt_file, + seed_file, + out_file, + topic_file, + verbose, + max_tokens=1000, + temperature=0.0, + top_p=1.0, + early_stop=1000 ): """ Generate high-level topics @@ -165,12 +176,14 @@ def generate_topic_lvl1( - out_file (str): File to write results to - topic_file (str): File to write topics to - verbose (bool): Whether to print out results + - max_tokens (int): Maximum number of tokens to generate (default: 1000) + - temperature (float): Sampling temperature (default: 0.0) + - top_p (float): Top-p sampling threshold (default: 1.0) Returns: - topics_root (TopicTree): Root node of the topic tree """ api_client = APIClient(api=api, model=model) - max_tokens, temperature, top_p = 1000, 0.0, 1.0 if verbose: print("-------------------") @@ -211,6 +224,7 @@ def generate_topic_lvl1( max_tokens, top_p, verbose, + early_stop=early_stop ) # Save generated topics @@ -273,6 +287,15 @@ def generate_topic_lvl1( parser.add_argument( "--verbose", type=bool, default=False, help="Whether to print out results" ) + parser.add_argument( + "--max_tokens", type=int, default=1000, help="Maximum number of tokens to generate" + ) + parser.add_argument( + "--temperature", type=float, default=0.0, help="Sampling temperature" + ) + parser.add_argument( + "--top_p", type=float, default=1.0, help="Top-p sampling threshold" + ) args = parser.parse_args() generate_topic_lvl1( args.api, @@ -283,4 +306,7 @@ def generate_topic_lvl1( args.out_file, args.topic_file, args.verbose, + args.max_tokens, + args.temperature, + args.top_p, ) diff --git a/topicgpt_python/generation_2.py b/topicgpt_python/generation_2.py index 69294a3..d215f23 100644 --- a/topicgpt_python/generation_2.py +++ b/topicgpt_python/generation_2.py @@ -48,7 +48,7 @@ def parse_document_topics(df, topics_list): Returns: List of topics for each document """ - pattern = 
regex.compile(r"^\[(\d+)\] ([\w\s]+):(.+)") + pattern = regex.compile(r"^\[(\d+)\] ([\w\s\+_#-]+):(.+)") all_topics = [] responses = ( @@ -236,7 +236,17 @@ def generate_topics( def generate_topic_lvl2( - api, model, seed_file, data, prompt_file, out_file, topic_file, verbose + api, + model, + seed_file, + data, + prompt_file, + out_file, + topic_file, + verbose, + max_tokens=1000, + temperature=0.0, + top_p=1.0 ): """ Generate subtopics for each top-level topic. @@ -250,11 +260,13 @@ def generate_topic_lvl2( - out_file: Output result file - topic_file: Output topics file - verbose: Enable verbose output + - max_tokens (int): Maximum number of tokens to generate (default: 1000) + - temperature (float): Sampling temperature (default: 0.0) + - top_p (float): Top-p sampling threshold (default: 1.0) Returns: Root node of the topic tree """ api_client = APIClient(api=api, model=model) - max_tokens, temperature, top_p = 1000, 0.0, 1.0 if verbose: print("-------------------") @@ -332,6 +344,15 @@ def generate_topic_lvl2( help="Output topics file", ) parser.add_argument("--verbose", action="store_true", help="Enable verbose output") + parser.add_argument( + "--max_tokens", type=int, default=1000, help="Maximum number of tokens to generate" + ) + parser.add_argument( + "--temperature", type=float, default=0.0, help="Sampling temperature" + ) + parser.add_argument( + "--top_p", type=float, default=1.0, help="Top-p sampling threshold" + ) args = parser.parse_args() generate_topic_lvl2( args.api, @@ -343,4 +364,7 @@ def generate_topic_lvl2( args.out_file, args.topic_file, args.verbose, + args.max_tokens, + args.temperature, + args.top_p, ) diff --git a/topicgpt_python/refinement.py b/topicgpt_python/refinement.py index 7681615..9a8ffda 100644 --- a/topicgpt_python/refinement.py +++ b/topicgpt_python/refinement.py @@ -91,9 +91,9 @@ def merge_topics( responses, orig_new = [], mapping pattern_topic = regex.compile( - r"^\[(\d+)\]([\w\s\-',]+)[^:]*:([\w\s,\.\-\/;']+) \(([^)]+)\)$" + 
r"^\[(\d+)\]([\w\s\-',_\+#]+)[^:]*:([\w\s,\.\-\/;']+) \(([^)]+)\)$" ) - pattern_original = regex.compile(r"\[(\d+)\]([\w\s\-',]+),?") + pattern_original = regex.compile(r"\[(\d+)\]([\w\s\-',_\+#]+),?") while len(new_pairs) > 1: refiner_prompt = refinement_prompt.format(Topics="\n".join(new_pairs)) @@ -241,6 +241,9 @@ def refine_topics( verbose, remove, mapping_file, + max_tokens=1000, + temperature=0.0, + top_p=1.0, ): """ Main function to refine topics by merging and updating based on API response. @@ -256,12 +259,14 @@ def refine_topics( - verbose (bool): If True, prints each replacement made. - remove (bool): If True, removes low-frequency topics. - mapping_file (str): Path to save the mapping as a JSON file. + - max_tokens (int): Maximum number of tokens to generate (default: 1000) + - temperature (float): Sampling temperature (default: 0.0) + - top_p (float): Top-p sampling threshold (default: 1.0) Returns: - None """ api_client = APIClient(api=api, model=model) - max_tokens, temperature, top_p = 1000, 0.0, 1.0 topics_root = TopicTree().from_topic_list(topic_file, from_file=True) if verbose: print("-------------------") @@ -323,6 +328,15 @@ def refine_topics( parser.add_argument( "--mapping_file", type=str, default="data/output/refiner_mapping.json" ) + parser.add_argument( + "--max_tokens", type=int, default=1000, help="Maximum number of tokens to generate" + ) + parser.add_argument( + "--temperature", type=float, default=0.0, help="Sampling temperature" + ) + parser.add_argument( + "--top_p", type=float, default=1.0, help="Top-p sampling threshold" + ) args = parser.parse_args() refine_topics( args.api, @@ -335,5 +349,8 @@ def refine_topics( args.updated_file, args.verbose, args.remove, - args.mapping_file + args.mapping_file, + args.max_tokens, + args.temperature, + args.top_p, ) diff --git a/topicgpt_python/utils.py b/topicgpt_python/utils.py index 2712864..aa64934 100644 --- a/topicgpt_python/utils.py +++ b/topicgpt_python/utils.py @@ -156,13 +156,20 @@ def 
iterative_prompt( for attempt in range(num_try): try: if self.api in ["openai", "azure", "ollama"]: + completion_params = { + "model": self.model, + "messages": message, + "temperature": temperature, + "top_p": top_p, + } + completion_params[ + "max_completion_tokens" if self.api == "openai" else "max_tokens" + ] = max_tokens + completion = self.client.chat.completions.create( - model=self.model, - messages=message, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, + **completion_params ) + if verbose: print( "Prompt token usage:",