diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 0c21cfab00..e39e8a1c89 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -913,11 +913,59 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"): return tokenizer +def get_tokenizer_tulu_no_pad_tok_addition(tc: "TokenizerConfig"): + config = AutoConfig.from_pretrained(tc.tokenizer_name_or_path, revision=tc.tokenizer_revision) + # @vwxyzjn: "olmo" handles both `olmo2` and `olmoe`. + if "olmo" in config.model_type: + if "olmo" in tc.chat_template_name: + assert not tc.add_bos, "For newer OLMo chat templates, you must *not* run with `--add_bos`." + else: + assert tc.add_bos, "For OLMo, you must run with `--add_bos`." + assert tc.use_fast, "For OLMo, you must use fast tokenizer." + + tokenizer = AutoTokenizer.from_pretrained( + tc.tokenizer_name_or_path, + revision=tc.tokenizer_revision, + trust_remote_code=tc.trust_remote_code, + use_fast=tc.use_fast, + ) + + # set the tokenizer chat template to the training format + # this will be used for encoding the training examples + # and saved together with the tokenizer to be used later. + if tc.chat_template_name in CHAT_TEMPLATES: + tokenizer.chat_template = CHAT_TEMPLATES[tc.chat_template_name] + else: + try: + if is_jinja_file(tc.chat_template_name): + with open(tc.chat_template_name) as f: + tokenizer.chat_template = f.read() + else: + tokenizer.chat_template = AutoTokenizer.from_pretrained( + tc.tokenizer_name_or_path, revision=tc.tokenizer_revision + ).chat_template + except Exception: + raise ValueError(f"Could not find chat template for {tc.tokenizer_name_or_path}.") + + if tc.add_bos: + if tokenizer.chat_template.startswith("{{ bos_token }}") or ( + tokenizer.bos_token is not None and tokenizer.chat_template.startswith(tokenizer.bos_token) + ): + raise ValueError( + "You specified add_bos=True, but the chat template already has a bos_token at the beginning." + ) + # also add bos in the chat template if not already there + tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template + + return tokenizer + + GET_TOKENIZER_FN = { "get_tokenizer_simple_v1": get_tokenizer_simple_v1, "get_tokenizer_tulu_v1": get_tokenizer_tulu_v1, # old version, see https://github.com/allenai/open-instruct/pull/570 "get_tokenizer_tulu_v2_1": get_tokenizer_tulu_v2_1, "get_tokenizer_tulu_v2_2": get_tokenizer_tulu_v2_2, + "get_tokenizer_tulu_no_pad_tok_addition": get_tokenizer_tulu_no_pad_tok_addition, } DEFAULT_SFT_MESSAGES_KEY = "messages"