From 651bee02ad5ab7804084b6292d574f95d7f67bbf Mon Sep 17 00:00:00 2001
From: Howard Yen
Date: Fri, 26 Dec 2025 22:13:24 -0500
Subject: [PATCH] add support for csv files, minor bug fixes for random
 indices and handling empty token sets in tokenize

---
 datatools/io_utils.py         | 28 ++++++++++++++++++++++++++++
 datatools/load.py             | 10 +++++++++-
 datatools/process.py          |  2 +-
 datatools/scripts/pack.py     | 22 ++++++++++++----------
 datatools/scripts/tokenize.py | 13 +++++++++----
 5 files changed, 59 insertions(+), 16 deletions(-)

diff --git a/datatools/io_utils.py b/datatools/io_utils.py
index fa6af69..e1cb261 100644
--- a/datatools/io_utils.py
+++ b/datatools/io_utils.py
@@ -1,6 +1,7 @@
 import json
 import os
 import io
+import random
 
 from typing import Any, Dict, Union, List, Optional
 from collections.abc import Sequence
@@ -62,6 +63,11 @@ def shard(cls, dataset: Array, shard_id: int, num_shards: int):
         shard_indices = np.linspace(0, N, num_shards + 1)
         return cls(dataset, range(int(shard_indices[shard_id]), int(shard_indices[shard_id + 1])))
 
+
+    @classmethod
+    def sample(cls, dataset: Array, num_samples: int, seed: int = 0):
+        rng = random.Random(seed)
+        return cls(dataset, rng.sample(range(len(dataset)), num_samples))
 
     def __len__(self) -> int:
         return len(self.indices)
@@ -190,6 +196,28 @@ def get_item(self, idx: int) -> Dict[str, Any]:
         return json.loads(self.lines[idx])
 
 
+class CSVDataset(Array):
+    def __init__(self, paths: List[Union[UPath, str]], is_tsv: bool = False):
+        self.paths = paths
+        self.is_tsv = is_tsv
+
+        dfs = []
+        for path in paths:
+            dfs.append(pd.read_csv(path, sep='\t' if is_tsv else ','))
+
+        self.df = pd.concat(dfs)
+
+    def __len__(self) -> int:
+        return len(self.df)
+
+    @property
+    def size(self) -> int:
+        return len(self.df)
+
+    def get_item(self, idx: int) -> Dict[str, Any]:
+        return self.df.iloc[idx].to_dict()
+
+
 class PyArrowDataset(Array):
     """PyArrow-based dataset that supports parquet and arrow files with local and S3 paths."""
 
diff --git a/datatools/load.py b/datatools/load.py
index 77ddaa4..234368b 100644
--- a/datatools/load.py
+++ b/datatools/load.py
@@ -7,7 +7,7 @@
 from upath import UPath
 import glob
 
-from datatools.io_utils import LocalDatasets, JsonlDataset, is_remote_path, has_compressed_mds_files, RemoteDatasets, PyArrowDataset
+from datatools.io_utils import LocalDatasets, JsonlDataset, is_remote_path, has_compressed_mds_files, RemoteDatasets, PyArrowDataset, CSVDataset
 
 def _expand_glob_patterns(input_paths: List[Union[UPath, str]]) -> List[UPath]:
     """Expand glob patterns in input paths for both local and remote paths."""
@@ -103,6 +103,10 @@ def load(*input_paths: List[Union[UPath, str]], options: Optional[LoadOptions] =
             if suffix in [".arrow", ".parquet", ".npy", ".jsonl"]:
                 input_type = suffix[1:]
                 break
+            # Attempt to load json as jsonl
+            if suffix == ".json":
+                input_type = "jsonl"
+                break
 
     if input_type == "mosaic":
         if any(is_remote_path(path) or has_compressed_mds_files(path) for path in input_paths):
@@ -111,6 +115,10 @@ def load(*input_paths: List[Union[UPath, str]], options: Optional[LoadOptions] =
         return LocalDatasets(input_paths)
     elif input_type == "jsonl":
         return JsonlDataset(input_paths)
+    elif input_type == "csv":
+        return CSVDataset(input_paths)
+    elif input_type == "tsv":
+        return CSVDataset(input_paths, is_tsv=True)
     elif input_type == "npy":
         return np.concatenate([np.load(str(path)) for path in input_paths])
     elif input_type in {"parquet", "arrow"}:
diff --git a/datatools/process.py b/datatools/process.py
index 41952bb..9a77a93 100644
--- a/datatools/process.py
+++ b/datatools/process.py
@@ -182,7 +182,7 @@ def load_indices(options):
 
     if options.index_range is not None:
         logger.info(f"Using indices from {options.index_range[0]} to {options.index_range[1]}")
-        indices = range(*options.index_range)
+        indices = np.arange(*options.index_range)
 
     return indices
 
diff --git a/datatools/scripts/pack.py b/datatools/scripts/pack.py
index f3f5115..b58deb3 100644
--- a/datatools/scripts/pack.py
+++ b/datatools/scripts/pack.py
@@ -210,7 +210,7 @@ def pack_fn(data: Array,
         field: add_special_tokens(np.array(item[field], dtype=np.uint32), options, bos=True, eos=True)
         for field in options.other_fields
     }
-    
+
     if options.split_by_lengths:
         while len(input_ids) >= sorted_lengths[-1]:
             # From longest to shortest
@@ -219,7 +219,7 @@ def pack_fn(data: Array,
                               if len(input_ids) >= target_len)
 
             target_subset = subset / f"{target_len}-{options.pack_length}"
-            
+
             other_iterators = {
                 field: iter(list(other_buffers[field][target_len].process(seq[:target_len])))
                 for field, seq in other_seqs.items()
@@ -227,18 +227,18 @@ def pack_fn(data: Array,
             for item in buffers[target_len].process(input_ids[:target_len]):
                 if options.domain_field:
                     item.update({options.domain_field: str(target_subset)})
-                
+
                 for field, iterator in other_iterators.items():
                     item[field] = next(iterator)[field]
                     assert len(item[field]) == len(item[options.token_field])
-                
+
                 yield target_subset, item
 
             if options.intact:
                 break
 
             input_ids = add_special_tokens(input_ids[target_len - options.overlap:], options, mos=True)
-            
+
             for field, iterator in other_iterators.items():
                 other_seqs[field] = add_special_tokens(other_seqs[field][target_len - options.overlap:], options, mos=True)
     else:
@@ -246,15 +246,15 @@ def pack_fn(data: Array,
             field: iter(list(other_buffers[field][subset].process(seq)))
             for field, seq in other_seqs.items()
         }
-        
+
        for item in buffers[subset].process(input_ids):
             if options.domain_field:
                 item.update({options.domain_field: str(subset)})
-            
+
             for field, iterator in other_iterators.items():
                 item[field] = next(iterator)[field]
                 assert len(item[field]) == len(item[options.token_field])
-            
+
             yield subset, item
 
 
@@ -267,7 +267,7 @@ def main():
     parser.add_arguments(PackOptions, dest="pack_options")
     parser.add_arguments(LoadOptions, dest="load_options")
     parser.add_arguments(ProcessOptions, dest="process_options")
-    
+
     parser.add_argument("-x", "--shuffle", action="store_true", help="Shuffle the dataset")
     parser.add_argument("--seed", type=int, default=42, help="Shuffle seed")
 
@@ -277,13 +277,15 @@ def main():
     dataset = load(*args.inputs, options=args.load_options)
     N = len(dataset)
     print(f"Loaded dataset with {N} samples")
-    
+
     if args.shuffle:
         indices = load_indices(args.process_options)
         if indices is None:
             indices = np.arange(N)
         np.random.seed(args.seed)
         args.process_options.indices = indices[np.random.permutation(len(indices))]
+        args.process_options.index_path = None
+        args.process_options.index_range = None
 
     process(dataset,
             partial(pack_fn, options=args.pack_options),
diff --git a/datatools/scripts/tokenize.py b/datatools/scripts/tokenize.py
index a21711d..8ffe045 100644
--- a/datatools/scripts/tokenize.py
+++ b/datatools/scripts/tokenize.py
@@ -54,7 +54,7 @@ def load_tokenizer_encoder(options: TokenizeOptions):
         from datatools.scripts.tokenizers.llama3_tokenizer import Tokenizer
         tokenizer = Tokenizer(str(Path(__file__).parent / "tokenizers" / "llama3_tokenizer.model"))
         from datatools.scripts.tokenizers.llama3_tokenizer import ChatFormat
-        
+
         if options.chat_template:
             chat_format = ChatFormat(tokenizer)
         def encode_fn(item):
@@ -84,7 +84,7 @@ def encode_fn(item):
             return tokens
 
         return encode_fn
-    
+
 
 
 def tokenize_fn(data: Array,
@@ -96,7 +96,7 @@ def tokenize_fn(data: Array,
     for i in tqdm(range(len(data)), desc=f"Process {process_id}"):
         item = data[i]
         domain = item[options.domain_by] if options.domain_by is not None else options.domain
-        
+
         if options.chat_template and options.chat_assistant_masking:
             tokens, masks = encode_fn(item)
 
@@ -109,7 +109,12 @@ def tokenize_fn(data: Array,
         output_item = {
             options.token_field: np.array(tokens, dtype=np.uint32),
         }
-        
+
+        if len(tokens) == 0:
+            # writing an array of length 0 will throw an error by MDS, which is undesirable as the rest of the data will be abandoned
+            # instead, we will write with a dummy token (0), and let the later user filter this out
+            output_item[options.token_field] = np.array([0], dtype=np.uint32)
+
         if options.length_field:
             output_item[options.length_field] = len(tokens)
         if options.domain_field:
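
Usage sketch (not part of the patch itself): CSVDataset and the switch from range to np.arange come straight from the diff above; the file name "train.csv" is hypothetical, and get_item is called directly so nothing is assumed about how Array exposes indexing.

    import numpy as np
    from datatools.io_utils import CSVDataset

    # CSVDataset reads each path with pandas and concatenates the frames;
    # rows come back as plain dicts. Pass is_tsv=True for tab-separated files.
    ds = CSVDataset(["train.csv"])
    print(len(ds), ds.get_item(0))

    # Why load_indices now builds np.arange instead of range: the shuffle path
    # in pack.py indexes the result with a permutation array, which a Python
    # range object rejects.
    indices = np.arange(len(ds))
    perm = np.random.permutation(len(indices))
    shuffled = indices[perm]           # works with np.arange
    # range(len(ds))[perm]             # would raise TypeError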
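
On the empty-token workaround in tokenize.py: the dummy substitution rewrites only the token array, so when options.length_field is configured the record still carries len(tokens) == 0, which gives downstream code an unambiguous way to drop these rows. A filtering sketch, with "input_ids" and "length" standing in for whatever token_field and length_field are set to:

    import numpy as np

    def keep(item, token_field="input_ids", length_field="length"):
        # Prefer the recorded length: the patch stores len(tokens) before the
        # dummy substitution, so empty inputs show up with length 0.
        if length_field in item:
            return item[length_field] > 0
        # Fallback heuristic: a single token with id 0 matches the dummy pattern.
        tokens = np.asarray(item[token_field])
        return not (len(tokens) == 1 and tokens[0] == 0)

    records = [
        {"input_ids": np.array([17, 42, 5], dtype=np.uint32), "length": 3},
        {"input_ids": np.array([0], dtype=np.uint32), "length": 0},  # dummy row for an empty input
    ]
    kept = [r for r in records if keep(r)]   # drops the dummy row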