Skip to content
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ce1b1f1
Add special tokens argument
divya-kumari32 Aug 14, 2025
cbca16d
Merge branch 'garrett361:main' into main
divya-kumari32 Aug 15, 2025
0b796e0
V7 tools string type change
divya-kumari32 Aug 22, 2025
56b8a3a
Changed asst_tag for think
divya-kumari32 Aug 26, 2025
25d4929
Changed sft_span fn wrt reviews
divya-kumari32 Aug 26, 2025
3e42f9e
function Changes per reviews
divya-kumari32 Aug 26, 2025
ae2c860
Update open_instruct/dataset_transformation.py
divya-kumari32 Aug 26, 2025
0939271
function Changes per reviews
divya-kumari32 Aug 26, 2025
9cfa6b2
ruff changes
divya-kumari32 Aug 26, 2025
3e74717
Update dataset_transformation.py
divya-kumari32 Aug 26, 2025
2ad362a
ruff changes
divya-kumari32 Aug 26, 2025
86beca8
Merge branch 'think-masking' of https://github.com/divya-kumari32/ope…
divya-kumari32 Aug 26, 2025
41d59c4
Added try/catch for json loads
divya-kumari32 Aug 26, 2025
17d9dba
Added check_sample flag
divya-kumari32 Aug 26, 2025
896932d
Merge branch 'think-masking' of https://github.com/divya-kumari32/ope…
divya-kumari32 Aug 26, 2025
7e6df5d
Added check_sample flag
divya-kumari32 Aug 26, 2025
45fa9a8
Update open_instruct/dataset_transformation.py
divya-kumari32 Aug 27, 2025
ea32413
Added think tag and changed masking flag
divya-kumari32 Aug 27, 2025
c03d076
Merge branch 'think-masking' of https://github.com/divya-kumari32/ope…
divya-kumari32 Aug 27, 2025
adcb1d2
Added think tag and changed masking flag
divya-kumari32 Aug 27, 2025
23a70f9
Added think tag and changed masking flag
divya-kumari32 Aug 27, 2025
fdb1131
ruff checks
divya-kumari32 Aug 27, 2025
f1b5bde
renamed variables and added more description
divya-kumari32 Sep 24, 2025
ae2b00e
renamed variables and added more description
divya-kumari32 Sep 24, 2025
99c1043
Changes after Fabian's review
divya-kumari32 Sep 24, 2025
988927d
Changes after Fabian's review
divya-kumari32 Sep 24, 2025
78c5947
ruff checks
divya-kumari32 Sep 24, 2025
280025d
Name changes
divya-kumari32 Sep 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions open_instruct/dataset_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,24 +1180,38 @@ def sft_span_seach_mask_out(
max_seq_length: int,
asst_tag: str = "<|start_of_role|>assistant<|end_of_role|>",
Comment thread
divya-kumari32 marked this conversation as resolved.
end_tag: str = "<|end_of_text|>",
think_tag: str = "\n<think>\n",
mask_think_tag: bool = False,
ignore_label: int = -100,
):
"""This function encodes a single example into a format that
can be used for sft training (similar to sft_tulu_tokenize_and_truncate_v1).
Instead of performing label masking iteratively, this function performs
masking via span search and can handle complex chat templates with thinking."""

# Span label masking strategy
# - search spans asst_tag ... end_tag
# - all such spans are left unmasked
# - if an asst_tage is undetected due to tokenization issues
# then a span can be erronously masked
# - to avoid this, use tags that are guarded by special tokens
masking via span search and can handle complex chat templates with thinking.
It dynamically determines the assistant tag based on the presence of a
<think> block in the assistant's response.
Comment thread
divya-kumari32 marked this conversation as resolved.
"""

# Dynamically determine the assistant tag based on the conversation content.
def has_thinking_content(messages):
    """Return True when any assistant turn is a "thinking" turn.

    A turn counts as thinking when it carries a truthy explicit 'thought'
    field, or when its string content embeds the (stripped) `think_tag`
    from the enclosing scope.
    """

    def _is_thinking_turn(message):
        # Only assistant messages can qualify.
        if message.get("role") != "assistant":
            return False
        # Explicit 'thought' field wins before any content inspection.
        if message.get("thought"):
            return True
        content = message.get("content")
        # `think_tag` is a closure variable from the enclosing function.
        return isinstance(content, str) and think_tag.strip() in content

    return any(_is_thinking_turn(m) for m in messages)

def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
# some prep
match = lambda x, y: torch.all(x == y)
_asst_tag = tokenizer.encode(asst_tag)
_end_tag = tokenizer.encode(end_tag)
# `asst_tag` is captured from the outer scope's dynamic variable
# By default, tokenizers for models like LLaMA, Mistral, and Phi-2 have add_bos_token=True, which means they # automatically prepend a BOS token. This would shift all token positions by 1 and break any span matching logic.
# Other models like Qwen, etc., have add_bos_token=False by default, so they behave differently — leading to # inconsistent behavior across model families if not explicitly handled.

_asst_tag = tokenizer.encode(asst_tag, add_special_tokens=False)
Comment thread
fabianlim marked this conversation as resolved.
_end_tag = tokenizer.encode(end_tag, add_special_tokens=False)
_asst_tag = torch.tensor([_asst_tag])
_end_tag = torch.tensor([_end_tag])

Expand All @@ -1223,7 +1237,7 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
else:
raise ValueError(
f"asst_tag has {num_tokens_asst} tokens, and end_tag has {num_tokens_end} "
"tokens, whereas the example has {num_tokens} tokens. Either "
f"tokens, whereas the example has {num_tokens} tokens. Either "
"the example is invalid or wrong tags have been passed."
)

Expand All @@ -1233,7 +1247,12 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
additional_inputs = {}
for k in ["tools", "documents"]:
if k in row:
additional_inputs[k] = row[k]
row_data = row[k]
try:
row_data = json.loads(row_data)
except (json.JSONDecodeError, TypeError) as e:
pass
additional_inputs[k] = row_data

if len(messages) == 0:
raise ValueError("messages field is empty.")
Expand All @@ -1248,6 +1267,12 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
**additional_inputs,
)

# If the user has set `mask_think_tag=True` and the current sample is a thinking sample,
# then the <think> token is appended to the base `asst_tag` used for span matching.
# This causes the <think> tag to be masked along with the asst_tag
if mask_think_tag and has_thinking_content(messages):
asst_tag += think_tag

# Assume truncation if hitting the exact max length (for downstream data filtering)
was_truncated = input_ids.shape[1] == max_seq_length
row["was_truncated"] = was_truncated
Expand Down