diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py
index 394df980c6..0fbb71398f 100644
--- a/open_instruct/dataset_transformation.py
+++ b/open_instruct/dataset_transformation.py
@@ -50,6 +50,7 @@
 from dataclasses import asdict, dataclass, field
 from functools import cached_property, partial
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+import random
 
 from transformers.training_args import _convert_str_dict
 import numpy as np
@@ -1233,7 +1234,11 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer):
    additional_inputs = {}
    for k in ["tools", "documents"]:
        if k in row:
-            additional_inputs[k] = row[k]
+            if k == "tools":
+                if isinstance(row[k], str) and len(row[k]) > 0:
+                    additional_inputs[k] = json.loads(row[k])
+            else:
+                additional_inputs[k] = row[k]
 
    if len(messages) == 0:
        raise ValueError("messages field is empty.")
@@ -1339,6 +1344,10 @@ def sft_tulu_filter_truncated_v1(row: Dict[str, Any], tokenizer: PreTrainedToken
        and not row.get("was_truncated", False)  # and was not truncated
    )
 
+def sft_tulu_filter_nothing(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    # No-op filter: do not apply any data filtering (keep every row)
+    return True
+
 
 def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer):
    # Extract prompt (all messages except the last one)
@@ -1573,6 +1582,7 @@ def rlvr_filter_v1(
    "sft_tulu_tokenize_and_truncate_v1": (sft_tulu_tokenize_and_truncate_v1, "map"),
    "sft_span_seach_mask_out": (sft_span_seach_mask_out, "map"),
    "sft_tulu_filter_v1": (sft_tulu_filter_v1, "filter"),
+    "sft_tulu_filter_nothing": (sft_tulu_filter_nothing, "filter"),
    "sft_tulu_filter_truncated_v1": (sft_tulu_filter_truncated_v1, "filter"),
    "preference_tokenize_v1": (preference_tokenize_v1, "map"),
    "preference_filter_v1": (preference_filter_v1, "filter"),
@@ -1691,8 +1701,8 @@ def select_samples(self, target_size: int):
            indices.extend(extra_indices.tolist())
 
        print(
-            f"Upsampling dataset {self.dataset_name} from {original_size} to {target_size} samples "
-            f"({full_repeats} full repeats + {extra_samples} random samples)"
+            f"\n Upsampling dataset {self.dataset_name} from {original_size:,} to {target_size:,} samples "
+            f"({full_repeats} full repeats + {extra_samples:,} random samples)"
        )
        return self.dataset.select(indices)
 
@@ -1739,7 +1749,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig):
                fn_kwargs=fn_kwargs,
                remove_columns=[col for col in dataset.column_names if col not in target_columns],
                num_proc=get_num_proc(len(dataset), num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU),
-                load_from_cache_file=False,  # force running from scratch (to ensure consistency across multiple datafiles)
+                load_from_cache_file=True,  #== set to False to force running from scratch (to ensure consistency across multiple data files)
            )
        elif fn_type == "filter":
            dataset = dataset.filter(
@@ -1886,6 +1896,7 @@ def load_or_transform_dataset(
        keep_in_memory: bool = False,
    ) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
        """Load dataset from local cache if it exists, otherwise transform and cache it locally."""
+
        cache_path = self.get_cache_path()
 
        # Check if the cache exists
@@ -1908,14 +1919,20 @@ def load_or_transform_dataset(
 
        # Transform each dataset and collect statistics
        transformed_datasets = []
+        total_left_samples = 0
        dataset_statistics = []
        dataset_order = []
+
+        all_ds_total_tokens = 0
+        all_ds_trainable_tokens = 0
+
        for i, dc in enumerate(dcs):
            initial_size = len(dc.dataset) if dc.dataset else 0
-            print(f"\n\n**** {i + 1}. Processing `{dc.dataset_name}` having {len(dc.dataset):,} samples...")
+            print(f"\n\n**** {i+1}. Processing `{dc.dataset_name}` with {len(dc.dataset):,} samples...")
            start_time = time.time()
-            dataset = get_dataset_v1(dc, tc)
+            dataset = get_dataset_v1(dc, tc)  #== tokenize/transform the dataset in dc (loaded from cache if it already exists)
            duration = time.time() - start_time
+            total_left_samples += len(dataset)
            transformed_datasets.append(dataset)
 
            # Collect statistics for this dataset
@@ -1927,6 +1944,7 @@ def load_or_transform_dataset(
                "instances_filtered": initial_size - len(dataset),
                "frac_or_num_samples": dc.frac_or_num_samples,
                "original_dataset_size": dc.original_dataset_size,
+                "process_time_in_second": duration,
                "is_upsampled": dc.is_upsampled,
                "upsampling_factor": dc.dataset_range / dc.original_dataset_size
                if dc.original_dataset_size and dc.original_dataset_size > 0
@@ -1934,22 +1952,60 @@ def load_or_transform_dataset(
            }
 
            # Count tokens if the dataset has been tokenized
-            if INPUT_IDS_KEY in dataset.column_names:
+            #== Counting tokens over the full dataset can take a long time. So, with N total samples, count tokens over:
+            #   max(0.005 * N, 5k) randomly selected samples,
+            #   or the whole dataset if it has fewer than 5k samples.
+
+            if INPUT_IDS_KEY in dataset.column_names and len(dataset) > 0:
                total_tokens = 0
                trainable_tokens = 0
-                for sample in dataset:
+
+                # Determine sample size: 0.5% of the dataset, but at least 5000 samples
+                sample_size = max(int(0.005 * len(dataset)), 5000)
+                sample_size = min(sample_size, len(dataset))  # cap at the dataset size if it has fewer than 5000 samples
+
+                # Randomly sample indices
+                sample_indices = random.sample(range(len(dataset)), sample_size)
+
+                # Accumulate token counts from the sampled subset
+                for idx in sample_indices:
+                    sample = dataset[idx]
                    tokens = len(sample[INPUT_IDS_KEY])
                    total_tokens += tokens
                    if LABELS_KEY in sample:
+                        # Count only tokens that are not ignored (-100)
                        trainable_tokens += sum(1 for label in sample[LABELS_KEY] if label != -100)
 
-                stats["total_tokens"] = total_tokens
-                stats["trainable_tokens"] = trainable_tokens
-                stats["avg_tokens_per_instance"] = total_tokens / len(dataset) if len(dataset) > 0 else 0
+                #== Rescale the sampled counts to approximate values for the full dataset
+                scale_factor = len(dataset) / sample_size
+                stats["total_tokens"] = int(total_tokens * scale_factor)
+                stats["trainable_tokens"] = int(trainable_tokens * scale_factor)
+                stats["avg_tokens_per_instance"] = round(total_tokens / sample_size, 2)
+
+                all_ds_total_tokens += stats["total_tokens"]
+                all_ds_trainable_tokens += stats["trainable_tokens"]
 
            dataset_statistics.append(stats)
            dataset_order.append(dc.dataset_name)
 
+            print(
+                f"\n**** Summary for {i+1}. {stats['dataset_name']} ({stats['dataset_split']}) ****\n"
+                f" - Original dataset size: {stats['original_dataset_size']:,}\n"
+                f" - Initial instances: {stats['initial_instances']:,}\n"
+                f" - Fraction or number of samples: {stats['frac_or_num_samples']}\n"
+                f" - Is upsampled: {stats['is_upsampled']} - Upsampling factor: {stats['upsampling_factor']:.2f}\n"
+                f" - Final instances: {stats['final_instances']:,}\n"
+                f" - Instances filtered: {stats['instances_filtered']:,}\n"
+                f" - Processing time: {stats['process_time_in_second']:,.2f} seconds\n"
+                + (
+                    f" - Total tokens: {stats['total_tokens']:,}\n"
+                    f" - Trainable tokens: {stats['trainable_tokens']:,}\n"
+                    f" - Avg tokens per instance: {stats['avg_tokens_per_instance']:,.1f}\n"
+                    if "total_tokens" in stats
+                    else ""
+                )
+            )
+        print(f"\n**** TOTAL NUM.SAMPLES AFTER DATA TRANSFORMATION: {total_left_samples:,} with TOTAL #TOKENS: {all_ds_total_tokens:,} #TRAINABLE TOKENS: {all_ds_trainable_tokens:,} ****\n")
 
        # Combine datasets
        combined_dataset = concatenate_datasets(transformed_datasets)
@@ -2060,6 +2116,7 @@ def get_cached_dataset_tulu_with_statistics(
            print(f"Dataset {dataset_name}: {original_size} -> {new_range} samples (factor: {frac_or_num_samples})")
            dataset_config.update_range(new_range)
        dcs.append(dataset_config)
+    #== Generate a deterministic hash of both configs (dcs and tc) for caching (used as the cache folder name)
    dataset_config_hash = compute_config_hash(dcs, tc)
    if dataset_cache_mode == "local":
        cache = LocalDatasetTransformationCache(
@@ -2067,6 +2124,8 @@ def get_cached_dataset_tulu_with_statistics(
        )
    elif dataset_cache_mode == "hf":
        cache = DatasetTransformationCache(config_hash=dataset_config_hash, hf_entity=hf_entity)
+
+    #== Either load the existing cached dataset or perform tokenization + transformation:
    return cache.load_or_transform_dataset(
        dcs,
        tc,
diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index dce84e657a..c0aedec611 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -495,6 +495,8 @@ def main(args: FlatArguments, tc: TokenizerConfig):
    if args.dataset_mixer is not None:
        args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair]
 
+
+    #== data tokenization:
    with accelerator.main_process_first():
        transform_fn_args = [{"max_seq_length": args.max_seq_length}, {}]
        train_dataset = get_cached_dataset_tulu(
@@ -518,6 +520,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
        # visualize_token(train_dataset[0][INPUT_IDS_KEY], tokenizer)
        visualize_token_label(train_dataset[0][INPUT_IDS_KEY], train_dataset[0][LABELS_KEY], tokenizer)
 
+    #== do not proceed to model training if cache_dataset_only is set (i.e., stop after data tokenization)
    if args.cache_dataset_only:
        return
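
For reference, the sampling-based token statistics added to load_or_transform_dataset above reduce to the standalone sketch below. It is only an illustration of the estimate, not part of the patch; estimate_token_stats, the toy list-of-dicts dataset, and the optional seed argument are hypothetical names introduced here (the patch itself calls random.sample without seeding, so its reported counts can vary slightly between runs).

import random

INPUT_IDS_KEY = "input_ids"
LABELS_KEY = "labels"


def estimate_token_stats(dataset, sample_frac=0.005, min_samples=5000, seed=None):
    """Estimate token counts for `dataset` from a random subset and rescale to the full size."""
    if seed is not None:
        random.seed(seed)
    n = len(dataset)
    # Sample max(0.5% of N, 5000) rows, capped at the dataset size.
    sample_size = min(max(int(sample_frac * n), min_samples), n)
    sample_indices = random.sample(range(n), sample_size)

    total_tokens = 0
    trainable_tokens = 0
    for idx in sample_indices:
        sample = dataset[idx]
        total_tokens += len(sample[INPUT_IDS_KEY])
        if LABELS_KEY in sample:
            # Labels of -100 are ignored by the loss, so they are not trainable.
            trainable_tokens += sum(1 for label in sample[LABELS_KEY] if label != -100)

    # Rescale the sampled counts to approximate the full dataset.
    scale_factor = n / sample_size
    return {
        "total_tokens": int(total_tokens * scale_factor),
        "trainable_tokens": int(trainable_tokens * scale_factor),
        "avg_tokens_per_instance": round(total_tokens / sample_size, 2),
    }


if __name__ == "__main__":
    # Toy dataset: 10,000 rows of 8 tokens each, with the first 4 labels masked out.
    toy_dataset = [{"input_ids": list(range(8)), "labels": [-100] * 4 + list(range(4))} for _ in range(10_000)]
    print(estimate_token_stats(toy_dataset, seed=0))
    # Expected: ~80,000 total tokens, ~40,000 trainable tokens, avg 8.0 tokens per instance.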