From eeeb0c423cf16554fb36b6f48a252a1991e05594 Mon Sep 17 00:00:00 2001 From: Xuan-Hong Dang Date: Thu, 18 Sep 2025 09:34:45 -0400 Subject: [PATCH 1/6] add more stats during data preprocessing --- open_instruct/dataset_transformation.py | 54 ++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 394df980c6..4102eaba6c 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1339,6 +1339,10 @@ def sft_tulu_filter_truncated_v1(row: Dict[str, Any], tokenizer: PreTrainedToken and not row.get("was_truncated", False) # and was not truncated ) +def sft_tulu_filter_nothing(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): + # To not apply any data filtering + return True + def preference_tokenize_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenizer): # Extract prompt (all messages except the last one) @@ -1573,6 +1577,7 @@ def rlvr_filter_v1( "sft_tulu_tokenize_and_truncate_v1": (sft_tulu_tokenize_and_truncate_v1, "map"), "sft_span_seach_mask_out": (sft_span_seach_mask_out, "map"), "sft_tulu_filter_v1": (sft_tulu_filter_v1, "filter"), + "sft_tulu_filter_nothing": (sft_tulu_filter_nothing, "filter"), "sft_tulu_filter_truncated_v1": (sft_tulu_filter_truncated_v1, "filter"), "preference_tokenize_v1": (preference_tokenize_v1, "map"), "preference_filter_v1": (preference_filter_v1, "filter"), @@ -1908,16 +1913,29 @@ def load_or_transform_dataset( # Transform each dataset and collect statistics transformed_datasets = [] + total_left_samples = 0 dataset_statistics = [] dataset_order = [] for i, dc in enumerate(dcs): initial_size = len(dc.dataset) if dc.dataset else 0 - print(f"\n\n**** {i + 1}. Processing `{dc.dataset_name}` having {len(dc.dataset):,} samples...") + print(f"\n\n**** {i+1}. 
Processing `{dc.dataset_name}` with {len(dc.dataset):,} selected top samples...") start_time = time.time() dataset = get_dataset_v1(dc, tc) + total_tokens, avg_tokens, std_tokens = count_total_tokens(dataset) duration = time.time() - start_time + print( + f"\n**** Summary for {i + 1}. {dc.dataset_name}:\n" + f" - No.of input samples: {len(dc.dataset):,}\n" + f" - No. of output samples (after processing): {len(dataset):,}\n" + f" - Total tokens: {total_tokens:,}\n" + f" - Avg tokens per sample: {avg_tokens:,.1f}\n" + f" - Stddev tokens per sample: {std_tokens:.2f}\n" + f" - Processing time: {duration:,.2f} seconds\n" + ) + total_left_samples += len(dataset) transformed_datasets.append(dataset) + # Collect statistics for this dataset stats = { "dataset_name": dc.dataset_name, @@ -1927,6 +1945,7 @@ def load_or_transform_dataset( "instances_filtered": initial_size - len(dataset), "frac_or_num_samples": dc.frac_or_num_samples, "original_dataset_size": dc.original_dataset_size, + "process_time_in_second": int(duration), "is_upsampled": dc.is_upsampled, "upsampling_factor": dc.dataset_range / dc.original_dataset_size if dc.original_dataset_size and dc.original_dataset_size > 0 @@ -1949,7 +1968,26 @@ def load_or_transform_dataset( dataset_statistics.append(stats) dataset_order.append(dc.dataset_name) + print( + f"\n**** Summary for {stats['dataset_name']} ({stats['dataset_split']}) ****\n" + f" - Initial instances: {stats['initial_instances']:,}\n" + f" - Final instances: {stats['final_instances']:,}\n" + f" - Instances filtered: {stats['instances_filtered']:,}\n" + f" - Fraction or number of samples: {stats['frac_or_num_samples']}\n" + f" - Original dataset size: {stats['original_dataset_size']:,}\n" + f" - Is upsampled: {stats['is_upsampled']}\n" + f" - Upsampling factor: {stats['upsampling_factor']:.2f}\n" + f" - Processing time: {stats['process_time_in_second']:,} seconds\n" + + ( + f" - Total tokens: {stats['total_tokens']:,}\n" + f" - Trainable tokens: 
{stats['trainable_tokens']:,}\n" + f" - Avg tokens per instance: {stats['avg_tokens_per_instance']:,.1f}\n" + if "total_tokens" in stats + else "" + ) + ) + print(f"\n**** TOTAL NUM.SAMPLES AFTER DATA TRANSFORMATION: {total_left_samples:,} ****\n") # Combine datasets combined_dataset = concatenate_datasets(transformed_datasets) @@ -1979,6 +2017,20 @@ def load_or_transform_dataset( return loaded_dataset, None +def count_total_tokens(dataset): + # count num.tokens per ds: + def get_token_count(row): + return {"num_tokens": len(row[INPUT_IDS_KEY])} + + num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count()))) + token_counts = dataset.map(get_token_count, num_proc=num_proc) + total_tokens = sum(token_counts["num_tokens"]) + token_lengths = torch.tensor(token_counts["num_tokens"], dtype=torch.float) + + total_tokens = token_lengths.sum().item() + avg_tokens = token_lengths.mean().item() + std_tokens = token_lengths.std(unbiased=False).item() + return total_tokens, avg_tokens, std_tokens def get_cached_dataset( dcs: List[DatasetConfig], tc: TokenizerConfig, From 8153abe7717175ed4d53803c2851278b6a267263 Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Thu, 18 Sep 2025 15:51:35 +0000 Subject: [PATCH 2/6] temp update --- open_instruct/dataset_transformation.py | 46 ++++++------------------- 1 file changed, 10 insertions(+), 36 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 4102eaba6c..2bcc40ff32 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1696,8 +1696,8 @@ def select_samples(self, target_size: int): indices.extend(extra_indices.tolist()) print( - f"Upsampling dataset {self.dataset_name} from {original_size} to {target_size} samples " - f"({full_repeats} full repeats + {extra_samples} random samples)" + f"\n Upsampling dataset {self.dataset_name} from {original_size:,} to {target_size:,} samples " + 
f"({full_repeats} full repeats + {extra_samples:,} random samples)" ) return self.dataset.select(indices) @@ -1744,7 +1744,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): fn_kwargs=fn_kwargs, remove_columns=[col for col in dataset.column_names if col not in target_columns], num_proc=get_num_proc(len(dataset), num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU), - load_from_cache_file=False, # force running from scratch (to ensure consistency across multiple datafiles) + load_from_cache_file=True, # set False to force running from scratch (to ensure consistency across multiple datafiles) ) elif fn_type == "filter": dataset = dataset.filter( @@ -1918,24 +1918,13 @@ def load_or_transform_dataset( dataset_order = [] for i, dc in enumerate(dcs): initial_size = len(dc.dataset) if dc.dataset else 0 - print(f"\n\n**** {i+1}. Processing `{dc.dataset_name}` with {len(dc.dataset):,} selected top samples...") + print(f"\n\n**** {i+1}. Processing `{dc.dataset_name}` with {len(dc.dataset):,} samples...") start_time = time.time() dataset = get_dataset_v1(dc, tc) - total_tokens, avg_tokens, std_tokens = count_total_tokens(dataset) duration = time.time() - start_time - print( - f"\n**** Summary for {i + 1}. {dc.dataset_name}:\n" - f" - No.of input samples: {len(dc.dataset):,}\n" - f" - No. 
of output samples (after processing): {len(dataset):,}\n" - f" - Total tokens: {total_tokens:,}\n" - f" - Avg tokens per sample: {avg_tokens:,.1f}\n" - f" - Stddev tokens per sample: {std_tokens:.2f}\n" - f" - Processing time: {duration:,.2f} seconds\n" - ) total_left_samples += len(dataset) transformed_datasets.append(dataset) - # Collect statistics for this dataset stats = { "dataset_name": dc.dataset_name, @@ -1945,7 +1934,7 @@ def load_or_transform_dataset( "instances_filtered": initial_size - len(dataset), "frac_or_num_samples": dc.frac_or_num_samples, "original_dataset_size": dc.original_dataset_size, - "process_time_in_second": int(duration), + "process_time_in_second": duration, "is_upsampled": dc.is_upsampled, "upsampling_factor": dc.dataset_range / dc.original_dataset_size if dc.original_dataset_size and dc.original_dataset_size > 0 @@ -1969,15 +1958,14 @@ def load_or_transform_dataset( dataset_statistics.append(stats) dataset_order.append(dc.dataset_name) print( - f"\n**** Summary for {stats['dataset_name']} ({stats['dataset_split']}) ****\n" + f"\n**** Summary for {i+1}. 
{stats['dataset_name']} ({stats['dataset_split']}) ****\n" + f" - Original dataset size: {stats['original_dataset_size']:,}\n" f" - Initial instances: {stats['initial_instances']:,}\n" + f" - Fraction or number of samples: {stats['frac_or_num_samples']}\n" + f" - Is upsampled: {stats['is_upsampled']} - Upsampling factor: {stats['upsampling_factor']:.2f}\n" f" - Final instances: {stats['final_instances']:,}\n" f" - Instances filtered: {stats['instances_filtered']:,}\n" - f" - Fraction or number of samples: {stats['frac_or_num_samples']}\n" - f" - Original dataset size: {stats['original_dataset_size']:,}\n" - f" - Is upsampled: {stats['is_upsampled']}\n" - f" - Upsampling factor: {stats['upsampling_factor']:.2f}\n" - f" - Processing time: {stats['process_time_in_second']:,} seconds\n" + f" - Processing time: {stats['process_time_in_second']:,.2f} seconds\n" + ( f" - Total tokens: {stats['total_tokens']:,}\n" f" - Trainable tokens: {stats['trainable_tokens']:,}\n" @@ -2017,20 +2005,6 @@ def load_or_transform_dataset( return loaded_dataset, None -def count_total_tokens(dataset): - # count num.tokens per ds: - def get_token_count(row): - return {"num_tokens": len(row[INPUT_IDS_KEY])} - - num_proc = int(float(os.environ.get("BEAKER_ASSIGNED_CPU_COUNT", multiprocessing.cpu_count()))) - token_counts = dataset.map(get_token_count, num_proc=num_proc) - total_tokens = sum(token_counts["num_tokens"]) - token_lengths = torch.tensor(token_counts["num_tokens"], dtype=torch.float) - - total_tokens = token_lengths.sum().item() - avg_tokens = token_lengths.mean().item() - std_tokens = token_lengths.std(unbiased=False).item() - return total_tokens, avg_tokens, std_tokens def get_cached_dataset( dcs: List[DatasetConfig], tc: TokenizerConfig, From ddfe7b9b18c452205bfd158adb5b748af9c393e1 Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Thu, 18 Sep 2025 17:06:13 +0000 Subject: [PATCH 3/6] revise stats. 
logging during data preprocessing --- open_instruct/dataset_transformation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 2bcc40ff32..50f5899f2b 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1744,7 +1744,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): fn_kwargs=fn_kwargs, remove_columns=[col for col in dataset.column_names if col not in target_columns], num_proc=get_num_proc(len(dataset), num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU), - load_from_cache_file=True, # set False to force running from scratch (to ensure consistency across multiple datafiles) + load_from_cache_file=False, # force running from scratch (to ensure consistency across multiple datafiles) ) elif fn_type == "filter": dataset = dataset.filter( From 0ef5ff79f9803b79b9ae3e7d30d9d0357c5df772 Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Wed, 15 Oct 2025 17:22:29 +0000 Subject: [PATCH 4/6] to train g4l + tokenization resumed from cache/dataset --- open_instruct/dataset_transformation.py | 33 ++++++++++++++----------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 50f5899f2b..dbadfc4c14 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1233,7 +1233,11 @@ def masking_strategy_span_search(input_ids: torch.tensor, tokenizer): additional_inputs = {} for k in ["tools", "documents"]: if k in row: - additional_inputs[k] = row[k] + if k == "tools": + if isinstance(row[k], str) and len(row[k]) > 0: + additional_inputs[k] = json.loads(row[k]) + else: + additional_inputs[k] = row[k] if len(messages) == 0: raise ValueError("messages field is empty.") @@ -1744,7 +1748,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): fn_kwargs=fn_kwargs, 
remove_columns=[col for col in dataset.column_names if col not in target_columns], num_proc=get_num_proc(len(dataset), num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU), - load_from_cache_file=False, # force running from scratch (to ensure consistency across multiple datafiles) + load_from_cache_file=True, # False to force running from scratch (to ensure consistency across multiple datafiles) ) elif fn_type == "filter": dataset = dataset.filter( @@ -1942,18 +1946,19 @@ def load_or_transform_dataset( } # Count tokens if the dataset has been tokenized - if INPUT_IDS_KEY in dataset.column_names: - total_tokens = 0 - trainable_tokens = 0 - for sample in dataset: - tokens = len(sample[INPUT_IDS_KEY]) - total_tokens += tokens - if LABELS_KEY in sample: - trainable_tokens += sum(1 for label in sample[LABELS_KEY] if label != -100) - - stats["total_tokens"] = total_tokens - stats["trainable_tokens"] = trainable_tokens - stats["avg_tokens_per_instance"] = total_tokens / len(dataset) if len(dataset) > 0 else 0 + # #== This part takes quite a bit of time for large datasets, so it's better to make it optional (via parameter): + # if INPUT_IDS_KEY in dataset.column_names: + # total_tokens = 0 + # trainable_tokens = 0 + # for sample in dataset: + # tokens = len(sample[INPUT_IDS_KEY]) + # total_tokens += tokens + # if LABELS_KEY in sample: + # trainable_tokens += sum(1 for label in sample[LABELS_KEY] if label != -100) + + # stats["total_tokens"] = total_tokens + # stats["trainable_tokens"] = trainable_tokens + # stats["avg_tokens_per_instance"] = total_tokens / len(dataset) if len(dataset) > 0 else 0 dataset_statistics.append(stats) dataset_order.append(dc.dataset_name) From 51b05ad01ff08ea1ca7798975883c01f3454fa97 Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Wed, 22 Oct 2025 20:30:17 +0000 Subject: [PATCH 5/6] revise script for token counts during data tokenization on sample sets and rescale to approximate for full ds --- open_instruct/dataset_transformation.py | 
52 ++++++++++++++++++------- open_instruct/finetune.py | 3 ++ 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index dbadfc4c14..498b6f9187 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -50,6 +50,7 @@ from dataclasses import asdict, dataclass, field from functools import cached_property, partial from typing import Any, Dict, List, Literal, Optional, Tuple, Union +import random from transformers.training_args import _convert_str_dict import numpy as np @@ -1748,7 +1749,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): fn_kwargs=fn_kwargs, remove_columns=[col for col in dataset.column_names if col not in target_columns], num_proc=get_num_proc(len(dataset), num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU), - load_from_cache_file=True, # False to force running from scratch (to ensure consistency across multiple datafiles) + load_from_cache_file=True, #== False to force running from scratch (to ensure consistency across multiple datafiles) ) elif fn_type == "filter": dataset = dataset.filter( @@ -1895,6 +1896,7 @@ def load_or_transform_dataset( keep_in_memory: bool = False, ) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]: """Load dataset from local cache if it exists, otherwise transform and cache it locally.""" + cache_path = self.get_cache_path() # Check if the cache exists @@ -1924,7 +1926,7 @@ def load_or_transform_dataset( initial_size = len(dc.dataset) if dc.dataset else 0 print(f"\n\n**** {i+1}. 
Processing `{dc.dataset_name}` with {len(dc.dataset):,} samples...") start_time = time.time() - dataset = get_dataset_v1(dc, tc) + dataset = get_dataset_v1(dc, tc) #== tokenize/transform ds in dc (possibly load from cache if existing) duration = time.time() - start_time total_left_samples += len(dataset) transformed_datasets.append(dataset) @@ -1946,19 +1948,36 @@ def load_or_transform_dataset( } # Count tokens if the dataset has been tokenized - # #== This part takes quite a bit of time for large datasets, so it's better to make it optional (via parameter): - # if INPUT_IDS_KEY in dataset.column_names: - # total_tokens = 0 - # trainable_tokens = 0 - # for sample in dataset: - # tokens = len(sample[INPUT_IDS_KEY]) - # total_tokens += tokens - # if LABELS_KEY in sample: - # trainable_tokens += sum(1 for label in sample[LABELS_KEY] if label != -100) - - # stats["total_tokens"] = total_tokens - # stats["trainable_tokens"] = trainable_tokens - # stats["avg_tokens_per_instance"] = total_tokens / len(dataset) if len(dataset) > 0 else 0 + # #== This token count often takes long time. 
So, let N be the total samples, count tokens from: + # At most 0.005*N (or 0.5%) or 5k randomly selected samples + # Or the whole ds if its total samples is less than 5k + + + if INPUT_IDS_KEY in dataset.column_names: + total_tokens = 0 + trainable_tokens = 0 + + # Determine sample size: use 0.5% of the dataset or at least 5000 samples + sample_size = max(int(0.005 * len(dataset)), 5000) + sample_size = min(sample_size, len(dataset)) # cap at dataset size if smaller than 5K + + # Randomly sample indices + sample_indices = random.sample(range(len(dataset)), sample_size) + + # Accumulate token counts from the sampled subset + for idx in sample_indices: + sample = dataset[idx] + tokens = len(sample[INPUT_IDS_KEY]) + total_tokens += tokens + if LABELS_KEY in sample: + # Count only tokens that are not ignored (-100) + trainable_tokens += sum(1 for label in sample[LABELS_KEY] if label != -100) + + #== Rescale statistics to approximate values for the full dataset + scale_factor = len(dataset) / sample_size + stats["total_tokens"] = int(total_tokens * scale_factor) + stats["trainable_tokens"] = int(trainable_tokens * scale_factor) + stats["avg_tokens_per_instance"] = total_tokens / sample_size if sample_size > 0 else 0 dataset_statistics.append(stats) dataset_order.append(dc.dataset_name) @@ -2091,6 +2110,7 @@ def get_cached_dataset_tulu_with_statistics( print(f"Dataset {dataset_name}: {original_size} -> {new_range} samples (factor: {frac_or_num_samples})") dataset_config.update_range(new_range) dcs.append(dataset_config) + #== Generate a deterministic hash of both configs dcs and tc for caching (folder name) dataset_config_hash = compute_config_hash(dcs, tc) if dataset_cache_mode == "local": cache = LocalDatasetTransformationCache( @@ -2098,6 +2118,8 @@ def get_cached_dataset_tulu_with_statistics( ) elif dataset_cache_mode == "hf": cache = DatasetTransformationCache(config_hash=dataset_config_hash, hf_entity=hf_entity) + + #== Either load existing one or perform
tokenization+transformation: return cache.load_or_transform_dataset( dcs, tc, diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index dce84e657a..c0aedec611 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -495,6 +495,8 @@ def main(args: FlatArguments, tc: TokenizerConfig): if args.dataset_mixer is not None: args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair] + + #== data tokenization: with accelerator.main_process_first(): transform_fn_args = [{"max_seq_length": args.max_seq_length}, {}] train_dataset = get_cached_dataset_tulu( @@ -518,6 +520,7 @@ def main(args: FlatArguments, tc: TokenizerConfig): # visualize_token(train_dataset[0][INPUT_IDS_KEY], tokenizer) visualize_token_label(train_dataset[0][INPUT_IDS_KEY], train_dataset[0][LABELS_KEY], tokenizer) + #== not moving to model training if cache_dataset_only is set to True (i.e., stop after data tokenization) if args.cache_dataset_only: return From d79b4b08a6866752c6b709cb7295cba4c0cb1de2 Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Tue, 28 Oct 2025 20:48:17 +0000 Subject: [PATCH 6/6] approx token counts for CP training --- open_instruct/dataset_transformation.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 498b6f9187..0fbb71398f 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -1922,6 +1922,10 @@ def load_or_transform_dataset( total_left_samples = 0 dataset_statistics = [] dataset_order = [] + + all_ds_total_tokens = 0 + all_ds_trainable_tokens = 0 + for i, dc in enumerate(dcs): initial_size = len(dc.dataset) if dc.dataset else 0 print(f"\n\n**** {i+1}. Processing `{dc.dataset_name}` with {len(dc.dataset):,} samples...") start_time = time.time() @@ -1951,15 +1955,14 @@ def load_or_transform_dataset( # #== This token count often takes long time.
So, let N be the total samples, count tokens from: # At most 0.005*N (or 0.5%) or 5k randomly selected samples # Or the whole ds if its total samples is less than 5k - - if INPUT_IDS_KEY in dataset.column_names: + if INPUT_IDS_KEY in dataset.column_names and len(dataset) > 0: total_tokens = 0 trainable_tokens = 0 # Determine sample size: use 0.5% of the dataset or at least 5000 samples sample_size = max(int(0.005 * len(dataset)), 5000) - sample_size = min(sample_size, len(dataset)) # cap at dataset size if smaller than 5K + sample_size = min(sample_size, len(dataset)) # cap at dataset size if smaller than 5000 # Randomly sample indices sample_indices = random.sample(range(len(dataset)), sample_size) @@ -1977,7 +1980,10 @@ def load_or_transform_dataset( scale_factor = len(dataset) / sample_size stats["total_tokens"] = int(total_tokens * scale_factor) stats["trainable_tokens"] = int(trainable_tokens * scale_factor) - stats["avg_tokens_per_instance"] = total_tokens / sample_size if sample_size > 0 else 0 + stats["avg_tokens_per_instance"] = round(total_tokens / sample_size, 2) + + all_ds_total_tokens += stats["total_tokens"] + all_ds_trainable_tokens += stats["trainable_tokens"] dataset_statistics.append(stats) dataset_order.append(dc.dataset_name) @@ -1999,7 +2005,7 @@ def load_or_transform_dataset( ) ) - print(f"\n**** TOTAL NUM.SAMPLES AFTER DATA TRANSFORMATION: {total_left_samples:,} ****\n") + print(f"\n**** TOTAL NUM.SAMPLES AFTER DATA TRANSFORMATION: {total_left_samples:,} with TOTAL #TOKENS: {all_ds_total_tokens:,} #TRAINABLE TOKENS: {all_ds_trainable_tokens:,} ****\n") # Combine datasets combined_dataset = concatenate_datasets(transformed_datasets)