From ce1b1f16d535e1b48636698abee51517241c8bf0 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Thu, 14 Aug 2025 15:15:58 -0400 Subject: [PATCH 01/24] Add special tokens argument added --- open_instruct/dataset_transformation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 30fe3a9deb..8a97e3e475 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1196,8 +1196,8 @@ def sft_span_seach_mask_out( def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): # some prep match = lambda x, y: torch.all(x == y) - _asst_tag = tokenizer.encode(asst_tag) - _end_tag = tokenizer.encode(end_tag) + _asst_tag = tokenizer.encode(asst_tag, add_special_tokens=False) + _end_tag = tokenizer.encode(end_tag, add_special_tokens=False) _asst_tag = torch.tensor([_asst_tag]) _end_tag = torch.tensor([_end_tag]) From 0b796e001b059ee05413e7788a38851c4d8387a7 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Fri, 22 Aug 2025 13:22:56 -0400 Subject: [PATCH 02/24] V7 tools string type change --- open_instruct/dataset_transformation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 92c2c22b13..c5be762173 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1233,7 +1233,11 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): additional_inputs = {} for k in ["tools", "documents"]: if k in row: - additional_inputs[k] = row[k] + if k == "tools": + if isinstance(row[k], str) and len(row[k]) > 0: + additional_inputs[k] = json.loads(row[k]) + else: + additional_inputs[k] = row[k] if len(messages) == 0: raise ValueError("messages field is empty.") From 56b8a3ad4024dd8c72a7ea962a507d76eec32042 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Tue, 26 Aug 2025 11:58:34 
-0400 Subject: [PATCH 03/24] Changed asst_tag for think --- open_instruct/dataset_transformation.py | 48 +++++++++++++++---------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index c5be762173..5268ac5514 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1178,24 +1178,43 @@ def sft_span_seach_mask_out( row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int, - asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>", end_tag: str = "<|end_of_text|>", ignore_label: int = -100, ): """This function encodes a single example into a format that can be used for sft training (similar to sft_tulu_tokenize_and_truncate_v1). Instead of performing label masking iteratively, this function performs - masking via span search and can handle complex chat templates with thinking.""" - - # Span label masking strategy - # - search spans asst_tag ... end_tag - # - all such spans are left unmasked - # - if an asst_tage is undetected due to tokenization issues - # then a span can be erronously masked - # - to avoid this, use tags that are guarded by special tokens + masking via span search and can handle complex chat templates with thinking. + It dynamically determines the assistant tag based on the presence of a + block in the assistant's response. + """ + + messages = row["messages"] + if len(messages) == 0: + raise ValueError("messages field is empty.") + + # Dynamically determine the assistant tag based on the conversation content. + is_think_sample = False + for message in messages: + if message.get("role") == "assistant": + # Check for an explicit 'thought' field or a '' tag in the content. + if message.get("thought") or ( + isinstance(message.get("content"), str) and "" in message["content"] + ): + is_think_sample = True + break # A single 'think' block defines the sample type. 
+ + # Setting the appropriate assistant tag for the masking strategy. + asst_tag = ( + "<|start_of_role|>assistant<|end_of_role|>\n\n" + if is_think_sample + else "<|start_of_role|>assistant<|end_of_role|>" + ) + def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): # some prep match = lambda x, y: torch.all(x == y) + # `asst_tag` is now captured from the outer scope's dynamic variable _asst_tag = tokenizer.encode(asst_tag, add_special_tokens=False) _end_tag = tokenizer.encode(end_tag, add_special_tokens=False) _asst_tag = torch.tensor([_asst_tag]) @@ -1223,24 +1242,17 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): else: raise ValueError( f"asst_tag has {num_tokens_asst} tokens, and end_tag has {num_tokens_end} " - "tokens, whereas the example has {num_tokens} tokens. Either " + f"tokens, whereas the example has {num_tokens} tokens. Either " "the example is invalid or wrong tags have been passed." ) return labels - messages = row["messages"] additional_inputs = {} for k in ["tools", "documents"]: if k in row: - if k == "tools": - if isinstance(row[k], str) and len(row[k]) > 0: - additional_inputs[k] = json.loads(row[k]) - else: - additional_inputs[k] = row[k] + additional_inputs[k] = row[k] - if len(messages) == 0: - raise ValueError("messages field is empty.") input_ids = tokenizer.apply_chat_template( conversation=messages, tokenize=True, From 25d4929f9b2abdcbde7e4786240366bbcc112033 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Tue, 26 Aug 2025 12:49:39 -0400 Subject: [PATCH 04/24] Changed sft_span fn wrt reviews --- open_instruct/dataset_transformation.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 5268ac5514..26fbd3086b 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1178,6 +1178,7 @@ def sft_span_seach_mask_out( row: Dict[str, Any], 
tokenizer: PreTrainedTokenizer, max_seq_length: int, + asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>", end_tag: str = "<|end_of_text|>", ignore_label: int = -100, ): @@ -1206,9 +1207,9 @@ def sft_span_seach_mask_out( # Setting the appropriate assistant tag for the masking strategy. asst_tag = ( - "<|start_of_role|>assistant<|end_of_role|>\n\n" + asst_tag + "\n\n" if is_think_sample - else "<|start_of_role|>assistant<|end_of_role|>" + else asst_tag ) def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): @@ -1251,8 +1252,12 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): additional_inputs = {} for k in ["tools", "documents"]: if k in row: - additional_inputs[k] = row[k] - + if k == "tools": + if isinstance(row[k], str) and len(row[k]) > 0: + additional_inputs[k] = json.loads(row[k]) + else: + additional_inputs[k] = row[k] + input_ids = tokenizer.apply_chat_template( conversation=messages, tokenize=True, From 3e42f9ee49ef9af60364461e22e77d6086dfd759 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Tue, 26 Aug 2025 13:07:15 -0400 Subject: [PATCH 05/24] function Changes per reviews --- open_instruct/dataset_transformation.py | 41 +++++++++++-------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 26fbd3086b..b559fe1d01 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1189,29 +1189,19 @@ def sft_span_seach_mask_out( It dynamically determines the assistant tag based on the presence of a block in the assistant's response. """ - - messages = row["messages"] - if len(messages) == 0: - raise ValueError("messages field is empty.") # Dynamically determine the assistant tag based on the conversation content. - is_think_sample = False - for message in messages: - if message.get("role") == "assistant": - # Check for an explicit 'thought' field or a '' tag in the content. 
- if message.get("thought") or ( - isinstance(message.get("content"), str) and "" in message["content"] - ): - is_think_sample = True - break # A single 'think' block defines the sample type. - - # Setting the appropriate assistant tag for the masking strategy. - asst_tag = ( - asst_tag + "\n\n" - if is_think_sample - else asst_tag - ) - + def is_think(messages): + for message in messages: + if message.get("role") == "assistant": + # Check for an explicit 'thought' field or a '' tag in the content. + if message.get("thought") or ( + isinstance(message.get("content"), str) and "" in message["content"] + ): + return True + + return False + def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): # some prep match = lambda x, y: torch.all(x == y) @@ -1249,6 +1239,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): return labels + messages = row["messages"] additional_inputs = {} for k in ["tools", "documents"]: if k in row: @@ -1256,8 +1247,10 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): if isinstance(row[k], str) and len(row[k]) > 0: additional_inputs[k] = json.loads(row[k]) else: - additional_inputs[k] = row[k] - + additional_inputs[k] = row[k] + + if len(messages) == 0: + raise ValueError("messages field is empty.") input_ids = tokenizer.apply_chat_template( conversation=messages, tokenize=True, @@ -1268,6 +1261,8 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): add_generation_prompt=False, **additional_inputs, ) + + asst_tag = asst_tag + "\n\n" if is_think else asst_tag # Assume truncation if hitting the exact max length (for downstream data filtering) was_truncated = input_ids.shape[1] == max_seq_length From ae2c8605dd05409317cfcfff73af59da8644795b Mon Sep 17 00:00:00 2001 From: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Date: Tue, 26 Aug 2025 13:20:53 -0400 Subject: [PATCH 06/24] Update open_instruct/dataset_transformation.py Co-authored-by: Yu Chin 
Fabian Lim --- open_instruct/dataset_transformation.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index b559fe1d01..a0105d0157 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1243,11 +1243,13 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): additional_inputs = {} for k in ["tools", "documents"]: if k in row: - if k == "tools": - if isinstance(row[k], str) and len(row[k]) > 0: - additional_inputs[k] = json.loads(row[k]) - else: - additional_inputs[k] = row[k] + row_data = row[k] + try: + row_data = json.loads(row_data) + except json.decoder.JsonDecodeError: + pass + + additional_inputs[k] = row_data if len(messages) == 0: raise ValueError("messages field is empty.") From 09392712ccfb9abcbfb598e6e213ef40b3e1c8c6 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Tue, 26 Aug 2025 13:22:26 -0400 Subject: [PATCH 07/24] function Changes per reviews --- open_instruct/dataset_transformation.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index a0105d0157..562efb02fe 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1245,11 +1245,12 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): if k in row: row_data = row[k] try: - row_data = json.loads(row_data) - except json.decoder.JsonDecodeError: - pass - - additional_inputs[k] = row_data + if k == "tools" and isinstance(row_data, str) and row_data: + additional_inputs[k] = json.loads(row_data) + else: + additional_inputs[k] = row_data + except (json.JSONDecodeError, TypeError) as e: + print(f"Failed to parse '{k}': {e}") if len(messages) == 0: raise ValueError("messages field is empty.") From 9cfa6b2eab75f2299298fc93fc948fd78759c9de Mon Sep 17 00:00:00 
2001 From: divykum2 Date: Tue, 26 Aug 2025 13:34:29 -0400 Subject: [PATCH 08/24] ruff changes --- open_instruct/dataset_transformation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 562efb02fe..186b3be627 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1199,9 +1199,8 @@ def is_think(messages): isinstance(message.get("content"), str) and "" in message["content"] ): return True - return False - + def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): # some prep match = lambda x, y: torch.all(x == y) @@ -1264,7 +1263,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): add_generation_prompt=False, **additional_inputs, ) - + asst_tag = asst_tag + "\n\n" if is_think else asst_tag # Assume truncation if hitting the exact max length (for downstream data filtering) From 3e74717e5b58ddba385ebc0fc8158a3a91206d69 Mon Sep 17 00:00:00 2001 From: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Date: Tue, 26 Aug 2025 13:36:33 -0400 Subject: [PATCH 09/24] Update dataset_transformation.py --- open_instruct/dataset_transformation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 186b3be627..37eceeb924 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1244,7 +1244,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): if k in row: row_data = row[k] try: - if k == "tools" and isinstance(row_data, str) and row_data: + if k == "tools" and isinstance(row_data, str) and len(row_data) > 0: additional_inputs[k] = json.loads(row_data) else: additional_inputs[k] = row_data From 2ad362af4ec6b935ca01a5c85bc3cd7bf3ca70f5 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Tue, 26 Aug 2025 14:10:58 -0400 
Subject: [PATCH 10/24] ruff changes --- open_instruct/dataset_transformation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 186b3be627..15f0e8dc90 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1264,7 +1264,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): **additional_inputs, ) - asst_tag = asst_tag + "\n\n" if is_think else asst_tag + asst_tag = asst_tag + "\n\n" if is_think(messages) else asst_tag # Assume truncation if hitting the exact max length (for downstream data filtering) was_truncated = input_ids.shape[1] == max_seq_length From 41d59c4e723e8a6e439dfe71bc2a54dbdf64f0e1 Mon Sep 17 00:00:00 2001 From: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Date: Tue, 26 Aug 2025 14:38:06 -0400 Subject: [PATCH 11/24] Added try/catch for json loads Co-authored-by: Yu Chin Fabian Lim --- open_instruct/dataset_transformation.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index e7f36c3077..7793225a7c 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1244,12 +1244,10 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): if k in row: row_data = row[k] try: - if k == "tools" and isinstance(row_data, str) and len(row_data) > 0: - additional_inputs[k] = json.loads(row_data) - else: - additional_inputs[k] = row_data + row_data = json.loads(row_data) except (json.JSONDecodeError, TypeError) as e: - print(f"Failed to parse '{k}': {e}") + pass + additional_inputs[k] = row_data if len(messages) == 0: raise ValueError("messages field is empty.") From 17d9dba9f84012886098661c75af435687609b2f Mon Sep 17 00:00:00 2001 From: divykum2 Date: Tue, 26 Aug 2025 15:14:49 -0400 Subject: [PATCH 12/24] 
Added check_sample flag --- open_instruct/dataset_transformation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index e7f36c3077..8f81c4147a 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1180,6 +1180,7 @@ def sft_span_seach_mask_out( max_seq_length: int, asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>", end_tag: str = "<|end_of_text|>", + check_sample: bool = False, ignore_label: int = -100, ): """This function encodes a single example into a format that @@ -1264,8 +1265,10 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): **additional_inputs, ) - asst_tag = asst_tag + "\n\n" if is_think(messages) else asst_tag - + if check_sample: + if is_think(messages): + asst_tag += "\n\n" + # Assume truncation if hitting the exact max length (for downstream data filtering) was_truncated = input_ids.shape[1] == max_seq_length row["was_truncated"] = was_truncated From 7e6df5d646a64380b9d0272a5186f5b957b619d6 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Tue, 26 Aug 2025 15:18:26 -0400 Subject: [PATCH 13/24] Added check_sample flag --- open_instruct/dataset_transformation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 1829c400a2..080b18722a 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1266,7 +1266,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): if check_sample: if is_think(messages): asst_tag += "\n\n" - + # Assume truncation if hitting the exact max length (for downstream data filtering) was_truncated = input_ids.shape[1] == max_seq_length row["was_truncated"] = was_truncated From 45fa9a8ccd4813cce2b1ccf1a1d28a0c4b276c11 Mon Sep 17 00:00:00 2001 From: divya-kumari32 
<72085811+divya-kumari32@users.noreply.github.com> Date: Wed, 27 Aug 2025 14:23:43 -0400 Subject: [PATCH 14/24] Update open_instruct/dataset_transformation.py Co-authored-by: Yu Chin Fabian Lim --- open_instruct/dataset_transformation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 080b18722a..c957eb2e85 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1186,9 +1186,9 @@ def sft_span_seach_mask_out( """This function encodes a single example into a format that can be used for sft training (similar to sft_tulu_tokenize_and_truncate_v1). Instead of performing label masking iteratively, this function performs - masking via span search and can handle complex chat templates with thinking. - It dynamically determines the assistant tag based on the presence of a - block in the assistant's response. + masking via span search and can handle complex chat templates with thinking. If the think_tag + is to be excluded from the masking, please provide it appropriately and also set mask_think_tag=False + """ """ # Dynamically determine the assistant tag based on the conversation content. 
From ea32413277bf479497c3739c64c241e572b5b65c Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 27 Aug 2025 14:24:26 -0400 Subject: [PATCH 15/24] Added think tag and changed masking flag --- open_instruct/dataset_transformation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 080b18722a..80b4291c67 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1180,7 +1180,8 @@ def sft_span_seach_mask_out( max_seq_length: int, asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>", end_tag: str = "<|end_of_text|>", - check_sample: bool = False, + think_tag: str = "\n\n", + mask_think_tag: bool = False, ignore_label: int = -100, ): """This function encodes a single example into a format that @@ -1263,9 +1264,9 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): **additional_inputs, ) - if check_sample: + if mask_think_tag: if is_think(messages): - asst_tag += "\n\n" + asst_tag += think_tag # Assume truncation if hitting the exact max length (for downstream data filtering) was_truncated = input_ids.shape[1] == max_seq_length From adcb1d23a285f7eb4e023cfc52d25865d755a026 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 27 Aug 2025 14:51:43 -0400 Subject: [PATCH 16/24] Added think tag and changed masking flag --- open_instruct/dataset_transformation.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 36be3efbc6..62b5936f6f 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1187,9 +1187,9 @@ def sft_span_seach_mask_out( """This function encodes a single example into a format that can be used for sft training (similar to sft_tulu_tokenize_and_truncate_v1). 
Instead of performing label masking iteratively, this function performs - masking via span search and can handle complex chat templates with thinking. If the think_tag - is to be excluded from the masking, please provide it appropriately and also set mask_think_tag=False - """ + masking via span search and can handle complex chat templates with thinking. + It dynamically determines the assistant tag based on the presence of a + block in the assistant's response. """ # Dynamically determine the assistant tag based on the conversation content. @@ -1206,7 +1206,12 @@ def is_think(messages): def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): # some prep match = lambda x, y: torch.all(x == y) - # `asst_tag` is now captured from the outer scope's dynamic variable + # `asst_tag` is captured from the outer scope's dynamic variable + + # By default, tokenizers for models like LLaMA, Mistral, and Phi-2 have add_bos_token=True, which means they # automatically prepend a BOS token. This would shift all token positions by 1 and break any span matching logic. + + # Other models like Qwen, etc., have add_bos_token=False by default, so they behave differently — leading to # inconsistent behavior across model families if not explicitly handled. + _asst_tag = tokenizer.encode(asst_tag, add_special_tokens=False) _end_tag = tokenizer.encode(end_tag, add_special_tokens=False) _asst_tag = torch.tensor([_asst_tag]) @@ -1264,9 +1269,10 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): **additional_inputs, ) - if mask_think_tag: - if is_think(messages): - asst_tag += think_tag + if mask_think_tag and is_think(messages): + # if think tag is not to be masked, then it is to be included in the asst_tag + # which by token matching logic, will cause it to be omitted from the mask span. 
+ asst_tag += think_tag # Assume truncation if hitting the exact max length (for downstream data filtering) was_truncated = input_ids.shape[1] == max_seq_length From 23a70f9c0a40a22ce2dbc7e820e1ed6ee1cbaa6a Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 27 Aug 2025 14:54:47 -0400 Subject: [PATCH 17/24] Added think tag and changed masking flag --- open_instruct/dataset_transformation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 62b5936f6f..664f1ecd28 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1270,8 +1270,8 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): ) if mask_think_tag and is_think(messages): - # if think tag is not to be masked, then it is to be included in the asst_tag - # which by token matching logic, will cause it to be omitted from the mask span. + # if think tag is to be masked and the message is a thinking sample, + # then the think token is to be included in the asst_tag asst_tag += think_tag # Assume truncation if hitting the exact max length (for downstream data filtering) From fdb1131e46e8a6b1b9792592ad0187b733d7f4b2 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 27 Aug 2025 14:57:14 -0400 Subject: [PATCH 18/24] ruff checks --- open_instruct/dataset_transformation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 664f1ecd28..410b0a6f07 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1270,7 +1270,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): ) if mask_think_tag and is_think(messages): - # if think tag is to be masked and the message is a thinking sample, + # if think tag is to be masked and the message is a thinking sample, # then the think token is to be 
included in the asst_tag asst_tag += think_tag From f1b5bde12744d3b7a7d74b8cf9bfd820fef3a41b Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 24 Sep 2025 09:36:24 -0400 Subject: [PATCH 19/24] renamed variables and added more description --- open_instruct/dataset_transformation.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 410b0a6f07..e7be9db0bb 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1178,10 +1178,10 @@ def sft_span_seach_mask_out( row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int, - asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>", + asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>\n\n", end_tag: str = "<|end_of_text|>", think_tag: str = "\n\n", - mask_think_tag: bool = False, + append_think_tag: bool = False, ignore_label: int = -100, ): """This function encodes a single example into a format that @@ -1193,7 +1193,7 @@ def sft_span_seach_mask_out( """ # Dynamically determine the assistant tag based on the conversation content. - def is_think(messages): + def has_thinking_content(messages): for message in messages: if message.get("role") == "assistant": # Check for an explicit 'thought' field or a '' tag in the content. @@ -1207,9 +1207,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): # some prep match = lambda x, y: torch.all(x == y) # `asst_tag` is captured from the outer scope's dynamic variable - # By default, tokenizers for models like LLaMA, Mistral, and Phi-2 have add_bos_token=True, which means they # automatically prepend a BOS token. This would shift all token positions by 1 and break any span matching logic. - # Other models like Qwen, etc., have add_bos_token=False by default, so they behave differently — leading to # inconsistent behavior across model families if not explicitly handled. 
_asst_tag = tokenizer.encode(asst_tag, add_special_tokens=False) @@ -1269,9 +1267,9 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): **additional_inputs, ) - if mask_think_tag and is_think(messages): - # if think tag is to be masked and the message is a thinking sample, - # then the think token is to be included in the asst_tag + if append_think_tag and has_thinking_content(messages): + # if user has specified appendind think tag to base asst tag and the specific row's message column is a thinking sample, + # then token is to be included in the asst_tag asst_tag += think_tag # Assume truncation if hitting the exact max length (for downstream data filtering) From ae2b00e8d5584aeacb2c677c60554ec3e0c147df Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 24 Sep 2025 09:39:56 -0400 Subject: [PATCH 20/24] renamed variables and added more description --- open_instruct/dataset_transformation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index e7be9db0bb..4273c43de2 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1267,10 +1267,11 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): **additional_inputs, ) - if append_think_tag and has_thinking_content(messages): - # if user has specified appendind think tag to base asst tag and the specific row's message column is a thinking sample, - # then token is to be included in the asst_tag - asst_tag += think_tag + # if user has specified appendind think tag to base asst tag and the specific row's message column is a thinking sample, + # then token is to be included in the asst_tag + if append_think_tag: + if has_thinking_content(messages): + asst_tag += think_tag # Assume truncation if hitting the exact max length (for downstream data filtering) was_truncated = input_ids.shape[1] == max_seq_length From 
99c10437c3fecb9bb408bc5d1dc52fe031c621eb Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 24 Sep 2025 10:44:36 -0400 Subject: [PATCH 21/24] Changes after Fabian's review --- open_instruct/dataset_transformation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 4273c43de2..260d4f146d 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1178,7 +1178,7 @@ def sft_span_seach_mask_out( row: Dict[str, Any], tokenizer: PreTrainedTokenizer, max_seq_length: int, - asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>\n\n", + asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>", end_tag: str = "<|end_of_text|>", think_tag: str = "\n\n", append_think_tag: bool = False, @@ -1198,7 +1198,7 @@ def has_thinking_content(messages): if message.get("role") == "assistant": # Check for an explicit 'thought' field or a '' tag in the content.
if message.get("thought") or ( - isinstance(message.get("content"), str) and "" in message["content"] + isinstance(message.get("content"), str) and think_tag.strip() in message["content"] ): return True return False From 988927dd527a8413fb80427b0ff8cd07346cba71 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 24 Sep 2025 11:22:25 -0400 Subject: [PATCH 22/24] Changes after Fabian's review --- open_instruct/dataset_transformation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 260d4f146d..f97bedd35d 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1267,8 +1267,9 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): **additional_inputs, ) - # if user has specified appendind think tag to base asst tag and the specific row's message column is a thinking sample, - # then token is to be included in the asst_tag + # If the user has set `append_think_tag=True` and the current sample is a thinking sample, + # then the token is appended to the base `asst_tag` used for span matching.
+ # This causes the tag to be masked along with the asst_tag if append_think_tag: if has_thinking_content(messages): asst_tag += think_tag From 78c5947d5d15f08d9c083321100b539a3706cebe Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 24 Sep 2025 11:33:47 -0400 Subject: [PATCH 23/24] ruff checks --- open_instruct/dataset_transformation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index f97bedd35d..4714b2ef64 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1269,7 +1269,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): # If the user has set `append_think_tag=True` and the current sample is a thinking sample, # then the token is appended to the base `asst_tag` used for span matching. - # This causes the tag to be masked along with the asst_tag + # This causes the tag to be masked along with the asst_tag if append_think_tag: if has_thinking_content(messages): asst_tag += think_tag From 280025d6e628f4b9759de26239d7806fe5b1e394 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 24 Sep 2025 12:55:50 -0400 Subject: [PATCH 24/24] Name changes --- open_instruct/dataset_transformation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 4714b2ef64..1d703dfb57 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1181,7 +1181,7 @@ def sft_span_seach_mask_out( asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>", end_tag: str = "<|end_of_text|>", think_tag: str = "\n\n", - append_think_tag: bool = False, + mask_think_tag: bool = False, ignore_label: int = -100, ): """This function encodes a single example into a format that @@ -1267,12 +1267,11 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): 
**additional_inputs, ) - # If the user has set `append_think_tag=True` and the current sample is a thinking sample, + # If the user has set `mask_think_tag=True` and the current sample is a thinking sample, # then the token is appended to the base `asst_tag` used for span matching. # This causes the tag to be masked along with the asst_tag - if append_think_tag: - if has_thinking_content(messages): - asst_tag += think_tag + if mask_think_tag and has_thinking_content(messages): + asst_tag += think_tag # Assume truncation if hitting the exact max length (for downstream data filtering) was_truncated = input_ids.shape[1] == max_seq_length