81 changes: 70 additions & 11 deletions open_instruct/dataset_transformation.py
@@ -50,6 +50,7 @@
from dataclasses import asdict, dataclass, field
from functools import cached_property, partial
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import random

from transformers.training_args import _convert_str_dict
import numpy as np
@@ -1233,7 +1234,11 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
additional_inputs = {}
for k in ["tools", "documents"]:
if k in row:
additional_inputs[k] = row[k]
if k == "tools":
if isinstance(row[k], str) and len(row[k]) > 0:
additional_inputs[k] = json.loads(row[k])
else:
additional_inputs[k] = row[k]

if len(messages) == 0:
raise ValueError("messages field is empty.")
@@ -1339,6 +1344,10 @@ def sft_tulu_filter_truncated_v1(row: Dict[str, Any], tokenizer: PreTrainedToken
and not row.get("was_truncated", False) # and was not truncated
)

def sft_tulu_filter_nothing(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
# Do not apply any data filtering.
return True


def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
# Extract prompt (all messages except the last one)
@@ -1573,6 +1582,7 @@ def rlvr_filter_v1(
"sft_tulu_tokenize_and_truncate_v1": (sft_tulu_tokenize_and_truncate_v1, "map"),
"sft_span_seach_mask_out": (sft_span_seach_mask_out, "map"),
"sft_tulu_filter_v1": (sft_tulu_filter_v1, "filter"),
"sft_tulu_filter_nothing": (sft_tulu_filter_nothing, "filter"),
"sft_tulu_filter_truncated_v1": (sft_tulu_filter_truncated_v1, "filter"),
"preference_tokenize_v1": (preference_tokenize_v1, "map"),
"preference_filter_v1": (preference_filter_v1, "filter"),
@@ -1691,8 +1701,8 @@ def select_samples(self, target_size: int):
indices.extend(extra_indices.tolist())

print(
f"Upsampling dataset {self.dataset_name} from {original_size} to {target_size} samples "
f"({full_repeats} full repeats + {extra_samples} random samples)"
f"\n Upsampling dataset {self.dataset_name} from {original_size:,} to {target_size:,} samples "
f"({full_repeats} full repeats + {extra_samples:,} random samples)"
)

return self.dataset.select(indices)
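
Note: the updated log line above summarizes the upsampling scheme (full repeats of the dataset plus random extra samples). A minimal standalone sketch of that scheme follows; the helper name and seeding are illustrative, not part of this PR:

import numpy as np

def upsample_indices(original_size: int, target_size: int, seed: int = 0) -> list:
    # Repeat the full dataset as many times as it fits into the target size,
    # then top up with randomly chosen extra samples (no duplicates within the top-up).
    rng = np.random.default_rng(seed)
    full_repeats = target_size // original_size
    extra_samples = target_size % original_size
    indices = list(range(original_size)) * full_repeats
    indices.extend(rng.choice(original_size, size=extra_samples, replace=False).tolist())
    return indices

# e.g. upsample_indices(1_000, 2_500) yields 2 full repeats + 500 random extras (2,500 indices total).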
@@ -1739,7 +1749,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig):
fn_kwargs=fn_kwargs,
remove_columns=[col for col in dataset.column_names if col not in target_columns],
num_proc=get_num_proc(len(dataset), num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU),
load_from_cache_file=False, # force running from scratch (to ensure consistency across multiple datafiles)
load_from_cache_file=True, #== set to False to force running from scratch (to ensure consistency across multiple data files)
)
elif fn_type == "filter":
dataset = dataset.filter(
@@ -1886,6 +1896,7 @@ def load_or_transform_dataset(
keep_in_memory: bool = False,
) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
"""Load dataset from local cache if it exists, otherwise transform and cache it locally."""

cache_path = self.get_cache_path()

# Check if the cache exists
@@ -1908,14 +1919,20 @@

# Transform each dataset and collect statistics
transformed_datasets = []
total_left_samples = 0
dataset_statistics = []
dataset_order = []

all_ds_total_tokens = 0
all_ds_trainable_tokens = 0

for i, dc in enumerate(dcs):
initial_size = len(dc.dataset) if dc.dataset else 0
print(f"\n\n**** {i + 1}. Processing `{dc.dataset_name}` having {len(dc.dataset):,} samples...")
print(f"\n\n**** {i+1}. Processing `{dc.dataset_name}` with {len(dc.dataset):,} samples...")
start_time = time.time()
dataset = get_dataset_v1(dc, tc)
dataset = get_dataset_v1(dc, tc) #== tokenize/transform the dataset in dc (loaded from cache if one exists)
duration = time.time() - start_time
total_left_samples += len(dataset)
transformed_datasets.append(dataset)

# Collect statistics for this dataset
@@ -1927,29 +1944,68 @@
"instances_filtered": initial_size - len(dataset),
"frac_or_num_samples": dc.frac_or_num_samples,
"original_dataset_size": dc.original_dataset_size,
"process_time_in_second": duration,
"is_upsampled": dc.is_upsampled,
"upsampling_factor": dc.dataset_range / dc.original_dataset_size
if dc.original_dataset_size and dc.original_dataset_size > 0
else 1.0,
}

# Count tokens if the dataset has been tokenized
if INPUT_IDS_KEY in dataset.column_names:
#== Counting tokens over the full dataset can take a long time. With N total samples,
# estimate token counts from a random sample of max(0.5% of N, 5,000) instances,
# capped at N when the dataset has fewer than 5,000 samples.

if INPUT_IDS_KEY in dataset.column_names and len(dataset) > 0:
total_tokens = 0
trainable_tokens = 0
for sample in dataset:

# Determine sample size: use 0.5% of the dataset or at least 5000 samples
sample_size = max(int(0.005 * len(dataset)), 5000)
sample_size = min(sample_size, len(dataset)) # cap at dataset size if smaller than 5000

# Randomly sample indices
sample_indices = random.sample(range(len(dataset)), sample_size)

# Accumulate token counts from the sampled subset
for idx in sample_indices:
sample = dataset[idx]
tokens = len(sample[INPUT_IDS_KEY])
total_tokens += tokens
if LABELS_KEY in sample:
# Count only tokens that are not ignored (-100)
trainable_tokens += sum(1 for label in sample[LABELS_KEY] if label != -100)

stats["total_tokens"] = total_tokens
stats["trainable_tokens"] = trainable_tokens
stats["avg_tokens_per_instance"] = total_tokens / len(dataset) if len(dataset) > 0 else 0
#== Rescale statistics to approximate values for the full dataset
scale_factor = len(dataset) / sample_size
stats["total_tokens"] = int(total_tokens * scale_factor)
stats["trainable_tokens"] = int(trainable_tokens * scale_factor)
stats["avg_tokens_per_instance"] = round(total_tokens / sample_size, 2)

all_ds_total_tokens += stats["total_tokens"]
all_ds_trainable_tokens += stats["trainable_tokens"]

dataset_statistics.append(stats)
dataset_order.append(dc.dataset_name)
print(
f"\n**** Summary for {i+1}. {stats['dataset_name']} ({stats['dataset_split']}) ****\n"
f" - Original dataset size: {stats['original_dataset_size']:,}\n"
f" - Initial instances: {stats['initial_instances']:,}\n"
f" - Fraction or number of samples: {stats['frac_or_num_samples']}\n"
f" - Is upsampled: {stats['is_upsampled']} - Upsampling factor: {stats['upsampling_factor']:.2f}\n"
f" - Final instances: {stats['final_instances']:,}\n"
f" - Instances filtered: {stats['instances_filtered']:,}\n"
f" - Processing time: {stats['process_time_in_second']:,.2f} seconds\n"
+ (
f" - Total tokens: {stats['total_tokens']:,}\n"
f" - Trainable tokens: {stats['trainable_tokens']:,}\n"
f" - Avg tokens per instance: {stats['avg_tokens_per_instance']:,.1f}\n"
if "total_tokens" in stats
else ""
)
)

print(f"\n**** TOTAL NUM.SAMPLES AFTER DATA TRANSFORMATION: {total_left_samples:,} with TOTAL #TOKENS: {all_ds_total_tokens:,} #TRAINABLE TOKENS: {all_ds_trainable_tokens:,} ****\n")
# Combine datasets
combined_dataset = concatenate_datasets(transformed_datasets)

@@ -2060,13 +2116,16 @@ def get_cached_dataset_tulu_with_statistics(
print(f"Dataset {dataset_name}: {original_size} -> {new_range} samples (factor: {frac_or_num_samples})")
dataset_config.update_range(new_range)
dcs.append(dataset_config)
#== Generate a deterministic hash of both configs (dcs and tc) for caching (used as the cache folder name)
dataset_config_hash = compute_config_hash(dcs, tc)
if dataset_cache_mode == "local":
cache = LocalDatasetTransformationCache(
config_hash=dataset_config_hash, dataset_local_cache_dir=dataset_local_cache_dir
)
elif dataset_cache_mode == "hf":
cache = DatasetTransformationCache(config_hash=dataset_config_hash, hf_entity=hf_entity)

#== Either load an existing cache or perform tokenization + transformation:
return cache.load_or_transform_dataset(
dcs,
tc,
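
Note: the comment above keys the local/HF cache off a deterministic hash of the dataset and tokenizer configs. As a hedged illustration of that idea (this is not the repo's compute_config_hash; the serialization details below are assumptions), one way such a hash could be derived:

import hashlib
import json
from dataclasses import asdict

def config_hash_sketch(dcs, tc) -> str:
    # Illustrative only: serialize both configs deterministically (sorted keys,
    # str() fallback for non-JSON fields) and hash the result, so identical
    # configs always map to the same cache folder name.
    payload = json.dumps(
        {"dataset_configs": [asdict(dc) for dc in dcs], "tokenizer_config": asdict(tc)},
        sort_keys=True,
        default=str,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:10]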
3 changes: 3 additions & 0 deletions open_instruct/finetune.py
@@ -495,6 +495,8 @@ def main(args: FlatArguments, tc: TokenizerConfig):

if args.dataset_mixer is not None:
args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair]

#== data tokenization:
with accelerator.main_process_first():
transform_fn_args = [{"max_seq_length": args.max_seq_length}, {}]
train_dataset = get_cached_dataset_tulu(
@@ -518,6 +520,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
# visualize_token(train_dataset[0][INPUT_IDS_KEY], tokenizer)
visualize_token_label(train_dataset[0][INPUT_IDS_KEY], train_dataset[0][LABELS_KEY], tokenizer)

#== do not proceed to model training when cache_dataset_only is True (i.e., stop after data tokenization)
if args.cache_dataset_only:
return
