diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py
index 0c5d38bc1f..1d703dfb57 100644
--- a/open_instruct/dataset_transformation.py
+++ b/open_instruct/dataset_transformation.py
@@ -1180,24 +1180,38 @@ def sft_span_seach_mask_out(
     max_seq_length: int,
     asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>",
     end_tag: str = "<|end_of_text|>",
+    think_tag: str = "<think>\n\n",
+    mask_think_tag: bool = False,
     ignore_label: int = -100,
 ):
     """This function encodes a single example into a format that can be used
     for sft training (similar to sft_tulu_tokenize_and_truncate_v1).
     Instead of performing label masking iteratively, this function performs
-    masking via span search and can handle complex chat templates with thinking."""
-
-    # Span label masking strategy
-    # - search spans asst_tag ... end_tag
-    # - all such spans are left unmasked
-    # - if an asst_tage is undetected due to tokenization issues
-    #   then a span can be erronously masked
-    # - to avoid this, use tags that are guarded by special tokens
+    masking via span search and can handle complex chat templates with thinking.
+    It dynamically determines the assistant tag based on the presence of a
+    <think> block in the assistant's response.
+    """
+
+    # Dynamically determine the assistant tag based on the conversation content.
+    def has_thinking_content(messages):
+        for message in messages:
+            if message.get("role") == "assistant":
+                # Check for an explicit 'thought' field or a '<think>' tag in the content.
+                if message.get("thought") or (
+                    isinstance(message.get("content"), str) and think_tag.strip() in message["content"]
+                ):
+                    return True
+        return False
+
 
     def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
         # some prep
         match = lambda x, y: torch.all(x == y)
-        _asst_tag = tokenizer.encode(asst_tag)
-        _end_tag = tokenizer.encode(end_tag)
+        # `asst_tag` is captured from the outer scope's dynamic variable.
+        # By default, tokenizers for models like LLaMA, Mistral, and Phi-2 have add_bos_token=True, which means they
+        # automatically prepend a BOS token. This would shift all token positions by 1 and break any span-matching logic.
+        # Other models like Qwen have add_bos_token=False by default, so they behave differently, leading to
+        # inconsistent behavior across model families if not explicitly handled.
+
+        _asst_tag = tokenizer.encode(asst_tag, add_special_tokens=False)
+        _end_tag = tokenizer.encode(end_tag, add_special_tokens=False)
         _asst_tag = torch.tensor([_asst_tag])
         _end_tag = torch.tensor([_end_tag])
@@ -1223,7 +1237,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
         else:
             raise ValueError(
                 f"asst_tag has {num_tokens_asst} tokens, and end_tag has {num_tokens_end} "
-                "tokens, whereas the example has {num_tokens} tokens. Either "
+                f"tokens, whereas the example has {num_tokens} tokens. Either "
                 "the example is invalid or wrong tags have been passed."
            )
@@ -1233,7 +1247,12 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
     additional_inputs = {}
     for k in ["tools", "documents"]:
         if k in row:
-            additional_inputs[k] = row[k]
+            row_data = row[k]
+            try:
+                row_data = json.loads(row_data)
+            except (json.JSONDecodeError, TypeError):
+                pass
+            additional_inputs[k] = row_data
 
     if len(messages) == 0:
         raise ValueError("messages field is empty.")
@@ -1248,6 +1267,12 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
         **additional_inputs,
     )
 
+    # If the user has set `mask_think_tag=True` and the current sample is a thinking sample,
+    # then the <think> token is appended to the base `asst_tag` used for span matching.
+    # This causes the <think> tag to be masked along with the asst_tag.
+    if mask_think_tag and has_thinking_content(messages):
+        asst_tag += think_tag
+
     # Assume truncation if hitting the exact max length (for downstream data filtering)
     was_truncated = input_ids.shape[1] == max_seq_length
     row["was_truncated"] = was_truncated
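
For reviewers who want to see the BOS-shift problem described in the new comment, here is a minimal sketch outside the PR. It assumes a Hugging Face tokenizer with the LLaMA default `add_bos_token=True`; the `hf-internal-testing/llama-tokenizer` checkpoint and the `tag` string are stand-ins chosen only for illustration.

```python
# Minimal sketch (not part of the PR) of the BOS shift that
# add_special_tokens=False guards against. The checkpoint is a small public
# LLaMA test tokenizer, used here purely as a stand-in.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

tag = "assistant"
with_special = tokenizer.encode(tag)                               # [BOS, ...tag ids]
without_special = tokenizer.encode(tag, add_special_tokens=False)  # [...tag ids]

# With add_bos_token=True (the LLaMA default), the encoded tag carries a
# leading BOS id, so it can never match the same tag occurring mid-sequence.
assert with_special[0] == tokenizer.bos_token_id
assert with_special[1:] == without_special
```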
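The span-search strategy itself (find an `asst_tag ... end_tag` span, keep the tag masked, unmask the response) can be illustrated with a self-contained toy, shown below. This is a sketch of the idea only; `span_search_labels`, the toy token ids, and the single-example shape are invented here and are not the PR's implementation.

```python
import torch

def span_search_labels(input_ids, asst_tag_ids, end_tag_ids, ignore_label=-100):
    """Mask everything except the spans following asst_tag up to end_tag."""
    labels = torch.full_like(input_ids, ignore_label)
    n, a, e = input_ids.numel(), asst_tag_ids.numel(), end_tag_ids.numel()
    i = 0
    while i <= n - a:
        if torch.equal(input_ids[i : i + a], asst_tag_ids):
            # Tag found: scan forward for the matching end_tag.
            j = i + a
            while j <= n - e and not torch.equal(input_ids[j : j + e], end_tag_ids):
                j += 1
            end = min(j + e, n)
            # Unmask the response span; the asst_tag tokens themselves stay masked.
            labels[i + a : end] = input_ids[i + a : end]
            i = end
        else:
            i += 1
    return labels

# Toy ids: [7, 8] plays asst_tag, [9] plays end_tag.
ids = torch.tensor([1, 2, 7, 8, 5, 6, 9, 3])
print(span_search_labels(ids, torch.tensor([7, 8]), torch.tensor([9])))
# tensor([-100, -100, -100, -100,    5,    6,    9, -100])
```

Note that appending `think_tag` to `asst_tag` (as the new `mask_think_tag` branch does) simply lengthens the matched prefix, so the `<think>` delimiter falls on the masked side of the span boundary.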
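Finally, the interplay between `mask_think_tag` and `has_thinking_content` can be traced without a tokenizer. The snippet below mirrors the diff's logic; the sample conversation and the placement of `<think>` inside the content are invented for illustration.

```python
think_tag = "<think>\n\n"
asst_tag = "<|start_of_role|>assistant<|end_of_role|>"

messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "<think>\n\nSimple arithmetic.\n\n</think>4"},
]

def has_thinking_content(messages):
    # Same check as in the diff: an explicit 'thought' field, or the
    # stripped think tag appearing inside the assistant content.
    for message in messages:
        if message.get("role") == "assistant":
            if message.get("thought") or (
                isinstance(message.get("content"), str) and think_tag.strip() in message["content"]
            ):
                return True
    return False

mask_think_tag = True
if mask_think_tag and has_thinking_content(messages):
    asst_tag += think_tag  # the span matcher now masks up to and including <think>

print(repr(asst_tag))
# '<|start_of_role|>assistant<|end_of_role|><think>\n\n'
```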