From 6e706dd3aadd01ba13fde833f57612d50bb00b8a Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Wed, 18 Jun 2025 17:42:13 +0000 Subject: [PATCH 1/8] enable granite4 chattemplate --- .gitignore | 6 +- .../chat_templates/ct-granite4.jinja2 | 130 +++++++++++++++ open_instruct/dataset_processor.py | 65 ++++++-- open_instruct/finetune.py | 70 ++++++-- open_instruct/utils.py | 153 ++++++++++++++++++ 5 files changed, 395 insertions(+), 29 deletions(-) create mode 100644 open_instruct/chat_templates/ct-granite4.jinja2 diff --git a/.gitignore b/.gitignore index 423d441d13..5318d1024f 100644 --- a/.gitignore +++ b/.gitignore @@ -146,4 +146,8 @@ dmypy.json .idea/ .vscode -.DS_Store \ No newline at end of file +.DS_Store + + +# script: +open_instruct/finetune-*.py \ No newline at end of file diff --git a/open_instruct/chat_templates/ct-granite4.jinja2 b/open_instruct/chat_templates/ct-granite4.jinja2 new file mode 100644 index 0000000000..e653cc5b08 --- /dev/null +++ b/open_instruct/chat_templates/ct-granite4.jinja2 @@ -0,0 +1,130 @@ +{%- set tools_system_message_prefix = 'You are a helpful assistant with access to the following tools. You may call one or more tools to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n' %} +{%- set tools_system_message_suffix = '\n\n\nFor each tool call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.' %} +{%- set documents_system_message_prefix = 'You are a helpful assistant with access to the following documents. You may use one or more documents to assist with the user query.\n\nYou are given a list of documents within XML tags:\n' %} +{%- set documents_system_message_suffix = '\n\n\nWrite the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.' %} +{%- if available_tools is defined and available_tools %} + {%- set tools = available_tools %} +{%- endif %} +{%- set ns = namespace(tools_system_message=tools_system_message_prefix, + documents_system_message=documents_system_message_prefix, + system_message='', + last_query_index=messages|length - 1) %} +{%- if tools %} + {%- for tool in tools %} + {%- set ns.tools_system_message = ns.tools_system_message + '\n' + (tool | tojson) %} + {%- endfor %} + {%- set ns.tools_system_message = ns.tools_system_message + tools_system_message_suffix %} +{%- else %} + {%- set ns.tools_system_message = '' %} +{%- endif %} +{%- if documents %} + {%- for document in documents %} + {%- set ns.documents_system_message = ns.documents_system_message + '\n' + (document | tojson) %} + {%- endfor %} + {%- set ns.documents_system_message = ns.documents_system_message + documents_system_message_suffix %} +{%- else %} + {%- set ns.documents_system_message = '' %} +{%- endif %} +{%- if messages[0].role == 'system' %} + {%- set ns.system_message = messages[0].content %} + {%- if tools and documents %} + {%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message + '\n\n' + ns.documents_system_message %} + {%- elif tools %} + {%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message %} + {%- elif documents %} + {%- set ns.system_message = ns.system_message + '\n\n' + ns.documents_system_message %} + {%- endif %} +{%- else %} + {%- if tools and documents %} + {%- set ns.system_message = ns.tools_system_message + '\n\n' + ns.documents_system_message %} + {%- elif tools %} + {%- set ns.system_message = ns.tools_system_message %} + {%- elif documents %} + {%- set ns.system_message = ns.documents_system_message %} + {%- endif %} +{%- endif %} +{%- if ns.system_message %} + {{- '<|start_of_role|>system<|end_of_role|>' + ns.system_message + '<|end_of_text|>\n' }} +{%- endif %} +{%- for message in messages|reverse %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if message.role == 'user' and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.last_query_index = index %} + {% break %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- set content = namespace(val='') %} + {%- if message.content is string %} + {%- set content.val = message.content %} + {%- else %} + {%- if message.content is iterable %} + {%- for entry in message.content %} + {%- if entry.type== 'text' %} + {%- if content.val != '' %} + {%- set content.val = content.val + '\n' %} + {%- endif %} + {%- set content.val = content.val + entry.text %} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- endif %} + {%- if (message.role == 'user') or (message.role == 'system' and not loop.first) %} + {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val + '<|end_of_text|>\n' }} + {%- elif message.role == 'assistant' %} + {%- set thought = '' %} + {%- if message.thought is string %} + {%- set thought = message.thought %} + {%- else %} + {%- if '' in content.val %} + {%- set thought = content.val.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content.val = content.val.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and thought) %} + {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + '\n\n' + thought.strip('\n') + '\n\n\n' + content.val.lstrip('\n') }} + {%- else %} + {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }} + {%- endif %} + {%- else %} + {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content.val) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|end_of_text|>\n' }} + {%- elif message.role == 'tool' %} + {%- if loop.first or (messages[loop.index0 - 1].role != 'tool') %} + {{- '<|start_of_role|>user<|end_of_role|>' }} + {%- endif %} + {{- '\n\n' }} + {{- content.val }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != 'tool') %} + {{- '<|end_of_text|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_of_role|>assistant<|end_of_role|>' }} + {%- if thinking is defined and thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/open_instruct/dataset_processor.py b/open_instruct/dataset_processor.py index 891984dde8..289e697b27 100644 --- a/open_instruct/dataset_processor.py +++ b/open_instruct/dataset_processor.py @@ -40,6 +40,11 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer + +# chat templates defined in jinja2 files are stored in subfolder `./chat_templates` +# CHAT_TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "chat_templates") # working dir +CHAT_TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chat_templates") # this .py's dir + logging.basicConfig(level=logging.INFO) @@ -88,35 +93,49 @@ # note we added `{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}` # because we want the template to not output eos_token if `add_generation_prompt=True` CHAT_TEMPLATES = { - "simple_concat_with_space": ( + "simple_concat_with_space": { + "type": "inline", + "template": ( "{% for message in messages %}" "{{ ' ' if not loop.first else '' }}" "{{ message['content'] }}" "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" "{% endfor %}" - ), - "simple_concat_with_new_line": ( + ) + }, + "simple_concat_with_new_line": { + "type": "inline", + "template": ( "{% for message in messages %}" "{{ '\n' if not loop.first else '' }}" "{{ message['content'] }}" "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" "{% endfor %}" - ), - "simple_chat": ( + ) + }, + "simple_chat": { + "type": "inline", + "template": ( "{% for message in messages %}" "{{ '\n\n' if not loop.first else '' }}" "{{ message['role'].capitalize() + ': ' + message['content'] }}" "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" "{% endfor %}" - ), - "assistant_message_only": ( + ) + }, + "assistant_message_only": { + "type": "inline", + "template": ( "{% for message in messages %}" "{% if message['role'] == 'assistant' %}" "{{ message['content'] }}" "{% endif %}" "{% endfor %}" - ), - "zephyr": ( + ) + }, + "zephyr": { + "type": "inline", + "template": ( "{% for message in messages %}" "{% if message['role'] == 'user' %}" "{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}" @@ -129,8 +148,11 @@ "{{ '<|assistant|>\n' }}" "{% endif %}" "{% endfor %}" - ), - "tulu": ( + ) + }, + "tulu": { + "type": "inline", + "template": ( "{% for message in messages %}" "{% if message['role'] == 'system' %}" "{{ '<|system|>\n' + message['content'] + '\n' }}" @@ -147,8 +169,11 @@ "{{ '<|assistant|>\n' }}" "{% endif %}" "{% endfor %}" - ), - "granite": ( + ) + }, + "granite": { + "type": "inline", + "template": ( "{% for message in messages %}" "{% if message['role'] == 'assistant' %}" "{% if not loop.last %}" @@ -163,8 +188,11 @@ "{{ '<|assistant|>\n' }}" "{% endif %}" "{% endfor %}" - ), - "granite2": ( + ) + }, + "granite2": { + "type": "inline", + "template": ( "{%- if messages[0]['role'] == 'system' %}" "{%- set system_message = messages[0]['content'] %}" "{%- set loop_messages = messages[1:] %}" @@ -201,7 +229,12 @@ "{{ '<|end_of_role|>' }}" "{%- endif %}" "{%- endfor %}" - ), + ) + }, + "granite4": { + "type": "file", + "path": os.path.join(CHAT_TEMPLATE_DIR, "ct-granite4.jinja2") + }, } # flake8: noqa diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 9cc8a005ed..8c130d3e81 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -17,7 +17,7 @@ import json import logging import math -import os +import os,sys import random import shutil import subprocess @@ -67,6 +67,8 @@ maybe_use_ai2_hf_entity, maybe_use_ai2_wandb_entity, upload_metadata_to_hf, + debug_chat_template_tokenization, + stop_debugging ) logger = get_logger(__name__) @@ -118,6 +120,12 @@ class FlatArguments: ) }, ) + # List of special tokens to be added to tokenizer, eg: + add_special_tokens: Optional[List[str]] = field( + default=None, + metadata={"help": "List of additional special tokens to add to the tokenizer"}, + ) + use_flash_attn: bool = field( default=True, metadata={"help": "Whether to use flash attention in the model training"}, @@ -714,8 +722,8 @@ def main(args: FlatArguments): else args.tokenizer_revision ) if tokenizer_revision != args.model_revision: - # Warn user if tokenizer and model use different revisions; this is an unusual - # use case. + # Warn user if tokenizer and model use different revisions; + # this is an unusual use case. warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different from the model revision `{args.model_revision}`.""" logger.warning(warning) @@ -826,6 +834,18 @@ def main(args: FlatArguments): assert num_added_tokens == 1, ( "We detected no padding token but add_special_tokens did not add one." ) + + # add special tokens if they are provided: + if args.add_special_tokens is not None: + existing_special_tokens = tokenizer.special_tokens_map.get("additional_special_tokens", []) + new_special_tokens = [t for t in args.add_special_tokens if t not in existing_special_tokens] + if new_special_tokens: + all_special_tokens = existing_special_tokens + new_special_tokens + tokenizer.add_special_tokens({"additional_special_tokens": all_special_tokens}) + if accelerator.is_main_process: + print(f"\n== Updated special tokens ({len(existing_special_tokens)} -> {len(all_special_tokens)}): {all_special_tokens}") + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # on a small vocab and want a smaller embedding size, remove this test. @@ -846,7 +866,17 @@ def main(args: FlatArguments): # this will be used for encoding the training examples # and saved together with the tokenizer to be used later. if args.chat_template_name in CHAT_TEMPLATES: - tokenizer.chat_template = CHAT_TEMPLATES[args.chat_template_name] + template_config = CHAT_TEMPLATES.get(args.chat_template_name) + + accelerator.print(f"\n== template_config: {template_config}") + + if template_config["type"] == "inline": + tokenizer.chat_template = template_config["template"] + elif template_config["type"] == "file": + with open(template_config["path"], 'r') as f: + tokenizer.chat_template = f.read().strip() + else: + raise ValueError(f"Unknown chat template type: {template_config['type']}") else: try: tokenizer.chat_template = AutoTokenizer.from_pretrained( @@ -868,6 +898,18 @@ def main(args: FlatArguments): # also add bos in the chat template if not already there tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template + + if accelerator.is_main_process: + accelerator.print(f"\n **** debug_chat_template_tokenization ****") + # debug_chat_template_tokenization(tokenizer) + debug_chat_template_tokenization(tokenizer,"Default","Default","Default") + + # accelerator.wait_for_everyone() + # stop_debugging(accelerator) + + + + if args.use_lora: if args.use_qlora: model = prepare_model_for_kbit_training( @@ -928,8 +970,9 @@ def main(args: FlatArguments): ) # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + if accelerator.is_main_process: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"\nSample {index:,} of the training set: {train_dataset[index]}.") # DataLoaders creation: if args.padding_free: @@ -972,7 +1015,7 @@ def main(args: FlatArguments): }, ] - accelerator.print("Creating optimizer") + accelerator.print(f"\n **** Creating optimizer ****") if args.use_qlora: from bitsandbytes.optim import AdamW @@ -1034,14 +1077,17 @@ def main(args: FlatArguments): num_warmup_steps=num_warmup_steps, ) # Prepare everything with `accelerator`. - - accelerator.print("Preparing accelerator") + + + accelerator.print(f"\n **** Preparing accelerator ****") + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - accelerator.print(f"{model=}") - accelerator.print(f"{accelerator.state.fsdp_plugin=}") - accelerator.print(f"{args=}") + if accelerator.is_main_process: + accelerator.print(f"\n== {model=}") + accelerator.print(f"\n== {accelerator.state.fsdp_plugin=}") + accelerator.print(f"\n== {args=}") # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil( diff --git a/open_instruct/utils.py b/open_instruct/utils.py index f4f1f9e86d..bf4e4af103 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -33,6 +33,10 @@ from rich.pretty import pprint from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser +from transformers import PreTrainedTokenizer +from typing import List, Dict, Optional +from accelerate import Accelerator + MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -895,3 +899,152 @@ def upload_metadata_to_hf( repo_type="dataset", ) os.remove("tmp.json") + + +def _get_default_messages(): + # used for testing granite4 chat template + messages = [ + {"role": "user", "content": "Who?"}, + {"role": "assistant", "content": "LLM"}, + ] + return messages + +def _get_default_tools(): + # used for testing granite4 chat template + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather for a specified city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "Name of the city" + } + }, + "required": ["city"] + } + } + }, + { + "type": "function", + "function": { + "name": "get_time", + "description": "Get the current time for a specified location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "Coordinates of the location" + } + }, + "required": ["location"] + } + } + } + ] + return tools + +def _get_default_RAG_documents(): + # used for testing granite4 chat template + documents = [ + { + "doc_id": 1, + "title": "", + "text": "From the early 12th century, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", + "source": "" + }, + { + "doc_id": 2, + "title": "", + "text": "From long time ago, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", + "source": "" + }, + { + "doc_id": 3, + "title": "", + "text": "From yesterday, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", + "source": "" + } + ] + return documents + + +def debug_chat_template_tokenization( + tokenizer: PreTrainedTokenizer, + messages: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...] + tools: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...] + documents: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...] +) -> None: + """ + Applies chat template to a sample of messages/tools/documents and tokenizes the resulting text. + + Args: + tokenizer (PreTrainedTokenizer): The tokenizer instance. + messages: Either "Default", None, or a list of {"role": ..., "content": ...} dicts. + tools: Either "Default", None, or a list of tool dicts. + documents: Either "Default", None, or a list of document dicts. + Example: + debug_chat_template_tokenization(tokenizer) + debug_chat_template_tokenization(tokenizer,"Default","Default","Default") + + """ + print( + f"\n== Tokenizer info: {len(tokenizer):,} tokens (vocab_size={tokenizer.vocab_size:,}) ==" + f"\n== Special Tokens Map (len={len(tokenizer.special_tokens_map)}):" + ) + + # special_token -> tokenID: + for name, token_str in tokenizer.special_tokens_map.items(): + token_id = tokenizer.convert_tokens_to_ids(token_str) + print(f" {name:>20}: '{token_str}' --> ID: {token_id}") + + # default messages: + if messages in ("Default", None): + messages = _get_default_messages() + print(f"\n== Messages:\n{messages}") + + # Collect optional inputs for applying chat template: + additional_inputs = {} + if tools is not None: + additional_inputs["tools"] = _get_default_tools() if tools == "Default" else tools + print(f"\n== Tools:\n{tools}") + + if documents is not None: + additional_inputs["documents"] = _get_default_RAG_documents() if documents == "Default" else documents + print(f"\n== Documents:\n{documents}") + + text = tokenizer.apply_chat_template(messages, + tokenize=False, + **additional_inputs, + ) + + tokens = tokenizer.encode(text, add_special_tokens=False) + decoded_tokens = [tokenizer.decode([t]) for t in tokens] + zipped = list(zip(tokens, decoded_tokens)) + + + print(f"\n== Final Chat Template Output:\n{text}") + print(f"\n== Tokenization Output (len={len(tokens)}):\n{tokens}\n") + for token_id, token_str in zipped: + print(f"{token_id:6d} -> `{token_str}`") + + +def stop_debugging(accelerator: Accelerator) -> None: + """ + Stops debugging and cleans up distributed resources. + Args: + accelerator (Accelerator): The accelerator instance managing distributed training. + """ + + # ensure all processes wait until the main process finishes + accelerator.wait_for_everyone() + + # clean up distributed resources + accelerator.end_training() + accelerator.free_memory() + sys.exit("== STOP DEBUGGING ==") \ No newline at end of file From 37e2aed013e411304c9a46adb32fd9d41a7d5e63 Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Wed, 18 Jun 2025 17:50:11 +0000 Subject: [PATCH 2/8] enable granite4 chattemplate --- open_instruct/utils.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/open_instruct/utils.py b/open_instruct/utils.py index bf4e4af103..fc0964c778 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -903,10 +903,21 @@ def upload_metadata_to_hf( def _get_default_messages(): # used for testing granite4 chat template + # messages = [ + # {"role": "user", "content": "Who?"}, + # {"role": "assistant", "content": "LLM"}, + # ] messages = [ - {"role": "user", "content": "Who?"}, - {"role": "assistant", "content": "LLM"}, - ] + {"role": "system", "content": "You are a weather assistant that responds with relevant function calls instead of natural language."}, + {"role": "user", "content": "What's the weather like in Bengaluru?"}, + {"role": "assistant", "content": "get_coordinates(city='Bengaluru')"}, + {"role": "system", "content": "Coordinates retrieved successfully. You can now use weather-related functions with latitude and longitude."}, + {"role": "user", "content": "Can you tell me the current temperature there?"}, + {"role": "assistant", "content": "get_current_weather(lat=12.97, lon=77.59)"}, + {"role": "system", "content": "User has requested a multi-day forecast. Switch to forecast mode."}, + {"role": "user", "content": "Actually, I need the 3-day forecast for planning a trip."}, + {"role": "assistant", "content": "get_weather_forecast(lat=12.97, lon=77.59, days=3)"}, + ] return messages def _get_default_tools(): From ee2c5a8fd70b2a69046096e435f16462c1a2f396 Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Thu, 19 Jun 2025 14:17:48 +0000 Subject: [PATCH 3/8] support g4 chat template --- .gitignore | 4 - open_instruct/dataset_processor.py | 71 ++++------- open_instruct/finetune.py | 59 ++++------ open_instruct/utils_granite.py | 182 +++++++++++++++++++++++++++++ 4 files changed, 228 insertions(+), 88 deletions(-) create mode 100644 open_instruct/utils_granite.py diff --git a/.gitignore b/.gitignore index 5318d1024f..5d5a245a9b 100644 --- a/.gitignore +++ b/.gitignore @@ -147,7 +147,3 @@ dmypy.json .idea/ .vscode .DS_Store - - -# script: -open_instruct/finetune-*.py \ No newline at end of file diff --git a/open_instruct/dataset_processor.py b/open_instruct/dataset_processor.py index 289e697b27..befee2c4e5 100644 --- a/open_instruct/dataset_processor.py +++ b/open_instruct/dataset_processor.py @@ -41,10 +41,6 @@ from transformers import PreTrainedTokenizer -# chat templates defined in jinja2 files are stored in subfolder `./chat_templates` -# CHAT_TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "chat_templates") # working dir -CHAT_TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chat_templates") # this .py's dir - logging.basicConfig(level=logging.INFO) @@ -92,50 +88,41 @@ # flake8: noqa # note we added `{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}` # because we want the template to not output eos_token if `add_generation_prompt=True` + +CHAT_TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chat_templates") +with open(os.path.join(CHAT_TEMPLATE_DIR, "ct-granite4.jinja2"), 'r', encoding='utf-8') as fid: + granite4_template = fid.read().strip() + CHAT_TEMPLATES = { - "simple_concat_with_space": { - "type": "inline", - "template": ( + "simple_concat_with_space": ( "{% for message in messages %}" "{{ ' ' if not loop.first else '' }}" "{{ message['content'] }}" "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" "{% endfor %}" - ) - }, - "simple_concat_with_new_line": { - "type": "inline", - "template": ( + ), + "simple_concat_with_new_line": ( "{% for message in messages %}" "{{ '\n' if not loop.first else '' }}" "{{ message['content'] }}" "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" "{% endfor %}" - ) - }, - "simple_chat": { - "type": "inline", - "template": ( + ), + "simple_chat": ( "{% for message in messages %}" "{{ '\n\n' if not loop.first else '' }}" "{{ message['role'].capitalize() + ': ' + message['content'] }}" "{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}" "{% endfor %}" - ) - }, - "assistant_message_only": { - "type": "inline", - "template": ( + ), + "assistant_message_only": ( "{% for message in messages %}" "{% if message['role'] == 'assistant' %}" "{{ message['content'] }}" "{% endif %}" "{% endfor %}" - ) - }, - "zephyr": { - "type": "inline", - "template": ( + ), + "zephyr": ( "{% for message in messages %}" "{% if message['role'] == 'user' %}" "{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}" @@ -148,11 +135,8 @@ "{{ '<|assistant|>\n' }}" "{% endif %}" "{% endfor %}" - ) - }, - "tulu": { - "type": "inline", - "template": ( + ), + "tulu": ( "{% for message in messages %}" "{% if message['role'] == 'system' %}" "{{ '<|system|>\n' + message['content'] + '\n' }}" @@ -169,11 +153,8 @@ "{{ '<|assistant|>\n' }}" "{% endif %}" "{% endfor %}" - ) - }, - "granite": { - "type": "inline", - "template": ( + ), + "granite": ( "{% for message in messages %}" "{% if message['role'] == 'assistant' %}" "{% if not loop.last %}" @@ -188,11 +169,8 @@ "{{ '<|assistant|>\n' }}" "{% endif %}" "{% endfor %}" - ) - }, - "granite2": { - "type": "inline", - "template": ( + ), + "granite2": ( "{%- if messages[0]['role'] == 'system' %}" "{%- set system_message = messages[0]['content'] %}" "{%- set loop_messages = messages[1:] %}" @@ -229,13 +207,10 @@ "{{ '<|end_of_role|>' }}" "{%- endif %}" "{%- endfor %}" - ) - }, - "granite4": { - "type": "file", - "path": os.path.join(CHAT_TEMPLATE_DIR, "ct-granite4.jinja2") - }, + ), + "granite4": granite4_template } + # flake8: noqa # Performance tuning. Some rough numbers: diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 8c130d3e81..379950a72c 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -17,7 +17,7 @@ import json import logging import math -import os,sys +import os import random import shutil import subprocess @@ -67,10 +67,10 @@ maybe_use_ai2_hf_entity, maybe_use_ai2_wandb_entity, upload_metadata_to_hf, - debug_chat_template_tokenization, - stop_debugging ) +from open_instruct.utils_granite import debug_chat_template_tokenization, stop_debugging,add_special_chat_tokens + logger = get_logger(__name__) @@ -516,6 +516,12 @@ def encode_sft_example(example, tokenizer, max_seq_length): We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors. """ messages = example["messages"] + + additional_inputs = {} + for k in ["tools", "documents"]: + if k in example: + additional_inputs[k] = example[k] + if len(messages) == 0: raise ValueError("messages field is empty.") input_ids = tokenizer.apply_chat_template( @@ -526,6 +532,7 @@ def encode_sft_example(example, tokenizer, max_seq_length): truncation=True, max_length=max_seq_length, add_generation_prompt=False, + **additional_inputs, ) labels = input_ids.clone() # mask the non-assistant part for avoiding loss @@ -545,6 +552,7 @@ def encode_sft_example(example, tokenizer, max_seq_length): truncation=True, max_length=max_seq_length, add_generation_prompt=False, + **additional_inputs, ).shape[1] # next, we calculate the end index of this non-assistant message if ( @@ -562,6 +570,7 @@ def encode_sft_example(example, tokenizer, max_seq_length): truncation=True, max_length=max_seq_length, add_generation_prompt=True, + **additional_inputs, ).shape[1] else: # for the last message or the message that doesn't follow with an assistant message, @@ -574,6 +583,7 @@ def encode_sft_example(example, tokenizer, max_seq_length): truncation=True, max_length=max_seq_length, add_generation_prompt=False, + **additional_inputs, ).shape[1] # set the label to -100 for the non-assistant part labels[:, message_start_idx:message_end_idx] = -100 @@ -722,8 +732,8 @@ def main(args: FlatArguments): else args.tokenizer_revision ) if tokenizer_revision != args.model_revision: - # Warn user if tokenizer and model use different revisions; - # this is an unusual use case. + # Warn user if tokenizer and model use different revisions; this is an unusual + # use case. warning = f"""Requested tokenizer revision `{tokenizer_revision}` is different from the model revision `{args.model_revision}`.""" logger.warning(warning) @@ -837,15 +847,7 @@ def main(args: FlatArguments): # add special tokens if they are provided: if args.add_special_tokens is not None: - existing_special_tokens = tokenizer.special_tokens_map.get("additional_special_tokens", []) - new_special_tokens = [t for t in args.add_special_tokens if t not in existing_special_tokens] - if new_special_tokens: - all_special_tokens = existing_special_tokens + new_special_tokens - tokenizer.add_special_tokens({"additional_special_tokens": all_special_tokens}) - if accelerator.is_main_process: - print(f"\n== Updated special tokens ({len(existing_special_tokens)} -> {len(all_special_tokens)}): {all_special_tokens}") - - + tokenizer = add_special_chat_tokens(tokenizer,args.add_special_tokens) # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # on a small vocab and want a smaller embedding size, remove this test. @@ -866,17 +868,7 @@ def main(args: FlatArguments): # this will be used for encoding the training examples # and saved together with the tokenizer to be used later. if args.chat_template_name in CHAT_TEMPLATES: - template_config = CHAT_TEMPLATES.get(args.chat_template_name) - - accelerator.print(f"\n== template_config: {template_config}") - - if template_config["type"] == "inline": - tokenizer.chat_template = template_config["template"] - elif template_config["type"] == "file": - with open(template_config["path"], 'r') as f: - tokenizer.chat_template = f.read().strip() - else: - raise ValueError(f"Unknown chat template type: {template_config['type']}") + tokenizer.chat_template = CHAT_TEMPLATES[args.chat_template_name] else: try: tokenizer.chat_template = AutoTokenizer.from_pretrained( @@ -901,13 +893,10 @@ def main(args: FlatArguments): if accelerator.is_main_process: accelerator.print(f"\n **** debug_chat_template_tokenization ****") - # debug_chat_template_tokenization(tokenizer) - debug_chat_template_tokenization(tokenizer,"Default","Default","Default") - - # accelerator.wait_for_everyone() - # stop_debugging(accelerator) - + # debug_chat_template_tokenization(tokenizer,"Default","Default","Default") + debug_chat_template_tokenization(tokenizer,None,None,None) + stop_debugging(accelerator) if args.use_lora: @@ -1015,7 +1004,7 @@ def main(args: FlatArguments): }, ] - accelerator.print(f"\n **** Creating optimizer ****") + accelerator.print("Creating optimizer") if args.use_qlora: from bitsandbytes.optim import AdamW @@ -1077,10 +1066,8 @@ def main(args: FlatArguments): num_warmup_steps=num_warmup_steps, ) # Prepare everything with `accelerator`. - - - accelerator.print(f"\n **** Preparing accelerator ****") - + + accelerator.print("Preparing accelerator") model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) diff --git a/open_instruct/utils_granite.py b/open_instruct/utils_granite.py new file mode 100644 index 0000000000..026b8ed0e3 --- /dev/null +++ b/open_instruct/utils_granite.py @@ -0,0 +1,182 @@ +from transformers import PreTrainedTokenizer +from typing import List, Dict, Optional, Union +from accelerate import Accelerator +import sys + + +def _get_simple_messages(): + # used for testing granite4 chat template + messages = [ + {"role": "user", "content": "Who?"}, + {"role": "assistant", "content": "LLM"}, + ] + return messages + +def _get_default_messages(): + # used for testing granite4 chat template + # messages = [ + # {"role": "user", "content": "Who?"}, + # {"role": "assistant", "content": "LLM"}, + # ] + messages = [ + {"role": "system", "content": "You are a weather assistant that responds with relevant function calls instead of natural language."}, + {"role": "user", "content": "What's the weather like in Bengaluru?"}, + {"role": "assistant", "content": "get_coordinates(city='Bengaluru')"}, + {"role": "system", "content": "Coordinates retrieved successfully. You can now use weather-related functions with latitude and longitude."}, + {"role": "user", "content": "Can you tell me the current temperature there?"}, + {"role": "assistant", "content": "get_current_weather(lat=12.97, lon=77.59)"}, + {"role": "system", "content": "User has requested a multi-day forecast. Switch to forecast mode."}, + {"role": "user", "content": "Actually, I need the 3-day forecast for planning a trip."}, + {"role": "assistant", "content": "get_weather_forecast(lat=12.97, lon=77.59, days=3)"}, + ] + return messages + +def _get_default_tools(): + # used for testing granite4 chat template + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather for a specified city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "Name of the city" + } + }, + "required": ["city"] + } + } + }, + { + "type": "function", + "function": { + "name": "get_time", + "description": "Get the current time for a specified location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "Coordinates of the location" + } + }, + "required": ["location"] + } + } + } + ] + return tools + +def _get_default_RAG_documents(): + # used for testing granite4 chat template + documents = [ + { + "doc_id": 1, + "title": "", + "text": "From the early 12th century, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", + "source": "" + }, + { + "doc_id": 2, + "title": "", + "text": "From long time ago, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", + "source": "" + }, + { + "doc_id": 3, + "title": "", + "text": "From yesterday, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", + "source": "" + } + ] + return documents + +def add_special_chat_tokens(tokenizer, add_special_tokens:list): + existing_special_tokens = tokenizer.special_tokens_map.get("additional_special_tokens", []) + new_special_tokens = [t for t in add_special_tokens if t not in existing_special_tokens] + if new_special_tokens: + all_special_tokens = existing_special_tokens + new_special_tokens + tokenizer.add_special_tokens({"additional_special_tokens": all_special_tokens}) + + return tokenizer + + + +def debug_chat_template_tokenization( + tokenizer: PreTrainedTokenizer, + messages: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...] + tools: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...] + documents: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...] +) -> None: + """ + Applies chat template to a sample of messages/tools/documents and tokenizes the resulting text. + + Args: + tokenizer (PreTrainedTokenizer): The tokenizer instance. + messages: Either "Default", None, or a list of {"role": ..., "content": ...} dicts. + tools: Either "Default", None, or a list of tool dicts. + documents: Either "Default", None, or a list of document dicts. + Example: + debug_chat_template_tokenization(tokenizer) + debug_chat_template_tokenization(tokenizer,"Default","Default","Default") + + """ + print( + f"\n== Tokenizer info: {len(tokenizer):,} tokens (vocab_size={tokenizer.vocab_size:,}) ==" + f"\n== Special Tokens Map (len={len(tokenizer.special_tokens_map)}):" + ) + + # special_token -> tokenID: + for name, token_str in tokenizer.special_tokens_map.items(): + token_id = tokenizer.convert_tokens_to_ids(token_str) + print(f" {name:>20}: '{token_str}' --> ID: {token_id}") + + # default messages: + if messages in ("Default", None): + messages = _get_simple_messages() if messages is None else _get_default_messages() + print(f"\n== Messages:\n{messages}") + + # Collect optional inputs for applying chat template: + additional_inputs = {} + if tools is not None: + additional_inputs["tools"] = _get_default_tools() if tools == "Default" else tools + print(f"\n== Tools:\n{tools}") + + if documents is not None: + additional_inputs["documents"] = _get_default_RAG_documents() if documents == "Default" else documents + print(f"\n== Documents:\n{documents}") + + text = tokenizer.apply_chat_template(messages, + tokenize=False, + **additional_inputs, + ) + + tokens = tokenizer.encode(text, add_special_tokens=False) + decoded_tokens = [tokenizer.decode([t]) for t in tokens] + zipped = list(zip(tokens, decoded_tokens)) + + + print(f"\n== Chat Template Output:\n{text}") + print(f"\n== Tokenization Output (len={len(tokens)}):\n{tokens}\n") + for token_id, token_str in zipped: + print(f"{token_id:6d} -> `{token_str}`") + + +def stop_debugging(accelerator: Accelerator) -> None: + """ + Stops debugging and cleans up distributed resources. + Args: + accelerator (Accelerator): The accelerator instance managing distributed training. + """ + + # ensure all processes wait until the main process finishes + accelerator.wait_for_everyone() + + # clean up distributed resources + accelerator.end_training() + accelerator.free_memory() + sys.exit("== STOP DEBUGGING ==") \ No newline at end of file From 8a5c5905987eb5a32f1e7f4adeb5694cbe57effd Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Thu, 19 Jun 2025 14:26:20 +0000 Subject: [PATCH 4/8] support granite 4 chat template --- open_instruct/finetune.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 379950a72c..667f826716 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -36,7 +36,7 @@ from accelerate.utils import InitProcessGroupKwargs, set_seed from datasets import load_dataset from huggingface_hub import HfApi -from padding_free_collator import TensorDataCollatorWithFlattening +from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training from torch.utils.data import DataLoader from tqdm.auto import tqdm @@ -890,15 +890,6 @@ def main(args: FlatArguments): # also add bos in the chat template if not already there tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template - - if accelerator.is_main_process: - accelerator.print(f"\n **** debug_chat_template_tokenization ****") - # debug_chat_template_tokenization(tokenizer,"Default","Default","Default") - debug_chat_template_tokenization(tokenizer,None,None,None) - - stop_debugging(accelerator) - - if args.use_lora: if args.use_qlora: model = prepare_model_for_kbit_training( From e78bda07833525a6f6b717aca1bd5a5c75fabc9c Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Thu, 19 Jun 2025 14:54:36 +0000 Subject: [PATCH 5/8] support granite4 chat template --- open_instruct/finetune.py | 56 +++++++++++++++++++----------------- open_instruct/model_utils.py | 3 +- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 667f826716..3f0cb120d9 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -69,8 +69,6 @@ upload_metadata_to_hf, ) -from open_instruct.utils_granite import debug_chat_template_tokenization, stop_debugging,add_special_chat_tokens - logger = get_logger(__name__) @@ -120,12 +118,6 @@ class FlatArguments: ) }, ) - # List of special tokens to be added to tokenizer, eg: - add_special_tokens: Optional[List[str]] = field( - default=None, - metadata={"help": "List of additional special tokens to add to the tokenizer"}, - ) - use_flash_attn: bool = field( default=True, metadata={"help": "Whether to use flash attention in the model training"}, @@ -186,7 +178,7 @@ class FlatArguments: ) train_file: Optional[str] = field( default=None, - metadata={"help": "The input training data file (a json/jsonl file)."}, + metadata={"help": "The input training data file (a json/jsonl/parquet file or directory)."}, ) max_train_samples: Optional[int] = field( default=None, @@ -459,8 +451,29 @@ def __post_init__(self): ) else: if self.train_file is not None: - extension = self.train_file.split(".")[-1] - assert extension in ["json", "jsonl", "parquet"], ( + if os.path.isdir(self.train_file): + # just assume they + self.train_file = [ + os.path.join(self.train_file, x) + for x in os.listdir(self.train_file) + ] + self.train_file_type = [ + x.split(".")[-1] for x in self.train_file + ] + self.train_file_type = [ + x for x in self.train_file_type + if x in ["json", "jsonl", "parquet"] + ] + self.train_file_type = list(set(self.train_file_type)) # unique + # assume the directory cannot mix types + self.train_file_type = ( + None if len(self.train_file_type) == 0 else + self.train_file_type[0] + ) + else: + self.train_file_type = self.train_file.split(".")[-1] + + assert self.train_file_type in ["json", "jsonl", "parquet"], ( "`train_file` should be a json or a jsonl or parquet file." ) if ( @@ -696,12 +709,9 @@ def main(args: FlatArguments): dataset_args = {} if args.train_file is not None: data_files["train"] = args.train_file - data_type = "json" - if args.train_file.endswith('.parquet'): - data_type = "parquet" with accelerator.main_process_first(): raw_datasets = load_dataset( - data_type, + args.train_file_type, data_files=data_files, **dataset_args, ) @@ -844,10 +854,6 @@ def main(args: FlatArguments): assert num_added_tokens == 1, ( "We detected no padding token but add_special_tokens did not add one." ) - - # add special tokens if they are provided: - if args.add_special_tokens is not None: - tokenizer = add_special_chat_tokens(tokenizer,args.add_special_tokens) # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # on a small vocab and want a smaller embedding size, remove this test. @@ -950,9 +956,8 @@ def main(args: FlatArguments): ) # Log a few random samples from the training set: - if accelerator.is_main_process: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"\nSample {index:,} of the training set: {train_dataset[index]}.") + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # DataLoaders creation: if args.padding_free: @@ -1062,10 +1067,9 @@ def main(args: FlatArguments): model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - if accelerator.is_main_process: - accelerator.print(f"\n== {model=}") - accelerator.print(f"\n== {accelerator.state.fsdp_plugin=}") - accelerator.print(f"\n== {args=}") + accelerator.print(f"{model=}") + accelerator.print(f"{accelerator.state.fsdp_plugin=}") + accelerator.print(f"{args=}") # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil( diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index de15d55b60..f283511767 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -379,6 +379,7 @@ def save_with_accelerate( output_dir: str, use_lora: bool = False, model_attribute_to_save: Optional[str] = None, + safe_serialization: bool = True ) -> None: """`model_attribute_to_save` is for used to save PPO's policy instead of the full model""" # set the generation config to an empty setting to be safe. @@ -420,7 +421,7 @@ def save_with_accelerate( is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=state_dict, - safe_serialization=False, + safe_serialization=safe_serialization, ) if accelerator.is_main_process: From 7c1089697e8a5804dd43762e98574837e8caa777 Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Thu, 19 Jun 2025 17:52:37 +0000 Subject: [PATCH 6/8] g4 script update --- open_instruct/finetune.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 3f0cb120d9..8781e2b8c1 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -69,6 +69,8 @@ upload_metadata_to_hf, ) +from open_instruct.utils_granite import debug_chat_template_tokenization, stop_debugging,add_special_chat_tokens + logger = get_logger(__name__) @@ -118,6 +120,12 @@ class FlatArguments: ) }, ) + # List of special tokens to be added to tokenizer, eg: + add_special_tokens: Optional[List[str]] = field( + default=None, + metadata={"help": "List of additional special tokens to add to the tokenizer"}, + ) + use_flash_attn: bool = field( default=True, metadata={"help": "Whether to use flash attention in the model training"}, @@ -854,6 +862,10 @@ def main(args: FlatArguments): assert num_added_tokens == 1, ( "We detected no padding token but add_special_tokens did not add one." ) + + # add special tokens if they are provided: + if args.add_special_tokens is not None: + tokenizer = add_special_chat_tokens(tokenizer,args.add_special_tokens) # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # on a small vocab and want a smaller embedding size, remove this test. @@ -956,9 +968,10 @@ def main(args: FlatArguments): ) # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - + if accelerator.is_main_process: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"\nSample {index:,} of the training set: {train_dataset[index]}.") + # DataLoaders creation: if args.padding_free: accelerator.print("Using padding-free collation") @@ -1067,9 +1080,10 @@ def main(args: FlatArguments): model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - accelerator.print(f"{model=}") - accelerator.print(f"{accelerator.state.fsdp_plugin=}") - accelerator.print(f"{args=}") + if accelerator.is_main_process: + accelerator.print(f"\n== {model=}") + accelerator.print(f"\n== {accelerator.state.fsdp_plugin=}") + accelerator.print(f"\n== {args=}") # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil( From 5f6c29c57ab25d0e717be84af134fdaeac8660e9 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 20 Jun 2025 11:32:06 +0000 Subject: [PATCH 7/8] update template with gabe changes and remove think --- .../chat_templates/ct-granite4.jinja2 | 58 +++++++++++-------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/open_instruct/chat_templates/ct-granite4.jinja2 b/open_instruct/chat_templates/ct-granite4.jinja2 index e653cc5b08..33f7e5b469 100644 --- a/open_instruct/chat_templates/ct-granite4.jinja2 +++ b/open_instruct/chat_templates/ct-granite4.jinja2 @@ -26,7 +26,18 @@ {%- set ns.documents_system_message = '' %} {%- endif %} {%- if messages[0].role == 'system' %} - {%- set ns.system_message = messages[0].content %} + {%- if messages[0].content is string %} + {%- set ns.system_message = messages[0].content %} + {%- elif messages[0].content is iterable %} + {%- for entry in messages[0].content %} + {%- if entry.type== 'text' %} + {%- if ns.system_message != '' %} + {%- set ns.system_message = ns.system_message + '\n' %} + {%- endif %} + {%- set ns.system_message = ns.system_message + entry.text %} + {%- endif %} + {%- endfor %} + {%- endif %} {%- if tools and documents %} {%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message + '\n\n' + ns.documents_system_message %} {%- elif tools %} @@ -48,9 +59,26 @@ {%- endif %} {%- for message in messages|reverse %} {%- set index = (messages|length - 1) - loop.index0 %} - {%- if message.role == 'user' and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} - {%- set ns.last_query_index = index %} - {% break %} + {%- if message.role == 'user' %} + {%- set content = namespace(val='') %} + {%- if message.content is string %} + {%- set content.val = message.content %} + {%- else %} + {%- if message.content is iterable %} + {%- for entry in message.content %} + {%- if entry.type== 'text' %} + {%- if content.val != '' %} + {%- set content.val = content.val + '\n' %} + {%- endif %} + {%- set content.val = content.val + entry.text %} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- endif %} + {%-if not(content.val.startswith('') and content.val.endswith('')) %} + {%- set ns.last_query_index = index %} + {% break %} + {%- endif %} {%- endif %} {%- endfor %} {%- for message in messages %} @@ -72,24 +100,7 @@ {%- if (message.role == 'user') or (message.role == 'system' and not loop.first) %} {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val + '<|end_of_text|>\n' }} {%- elif message.role == 'assistant' %} - {%- set thought = '' %} - {%- if message.thought is string %} - {%- set thought = message.thought %} - {%- else %} - {%- if '' in content.val %} - {%- set thought = content.val.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} - {%- set content.val = content.val.split('')[-1].lstrip('\n') %} - {%- endif %} - {%- endif %} - {%- if loop.index0 > ns.last_query_index %} - {%- if loop.last or (not loop.last and thought) %} - {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + '\n\n' + thought.strip('\n') + '\n\n\n' + content.val.lstrip('\n') }} - {%- else %} - {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }} - {%- endif %} - {%- else %} - {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }} - {%- endif %} + {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if (loop.first and content.val) or (not loop.first) %} @@ -124,7 +135,4 @@ {%- endfor %} {%- if add_generation_prompt %} {{- '<|start_of_role|>assistant<|end_of_role|>' }} - {%- if thinking is defined and thinking is false %} - {{- '\n\n\n\n' }} - {%- endif %} {%- endif %} \ No newline at end of file From a9f420d16ec12e1eb4682b79d2aae1b187819b87 Mon Sep 17 00:00:00 2001 From: Xuan-Hong-dang Date: Thu, 26 Jun 2025 01:57:54 +0000 Subject: [PATCH 8/8] script to SFT granite4-light on w/o g4 chat template --- .../chat_templates/ct-granite4-v01.jinja2 | 130 +++++++++++++ open_instruct/finetune.py | 37 +++- open_instruct/utils.py | 184 ++---------------- open_instruct/utils_granite.py | 33 ++-- 4 files changed, 193 insertions(+), 191 deletions(-) create mode 100644 open_instruct/chat_templates/ct-granite4-v01.jinja2 diff --git a/open_instruct/chat_templates/ct-granite4-v01.jinja2 b/open_instruct/chat_templates/ct-granite4-v01.jinja2 new file mode 100644 index 0000000000..e653cc5b08 --- /dev/null +++ b/open_instruct/chat_templates/ct-granite4-v01.jinja2 @@ -0,0 +1,130 @@ +{%- set tools_system_message_prefix = 'You are a helpful assistant with access to the following tools. You may call one or more tools to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n' %} +{%- set tools_system_message_suffix = '\n\n\nFor each tool call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.' %} +{%- set documents_system_message_prefix = 'You are a helpful assistant with access to the following documents. You may use one or more documents to assist with the user query.\n\nYou are given a list of documents within XML tags:\n' %} +{%- set documents_system_message_suffix = '\n\n\nWrite the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.' %} +{%- if available_tools is defined and available_tools %} + {%- set tools = available_tools %} +{%- endif %} +{%- set ns = namespace(tools_system_message=tools_system_message_prefix, + documents_system_message=documents_system_message_prefix, + system_message='', + last_query_index=messages|length - 1) %} +{%- if tools %} + {%- for tool in tools %} + {%- set ns.tools_system_message = ns.tools_system_message + '\n' + (tool | tojson) %} + {%- endfor %} + {%- set ns.tools_system_message = ns.tools_system_message + tools_system_message_suffix %} +{%- else %} + {%- set ns.tools_system_message = '' %} +{%- endif %} +{%- if documents %} + {%- for document in documents %} + {%- set ns.documents_system_message = ns.documents_system_message + '\n' + (document | tojson) %} + {%- endfor %} + {%- set ns.documents_system_message = ns.documents_system_message + documents_system_message_suffix %} +{%- else %} + {%- set ns.documents_system_message = '' %} +{%- endif %} +{%- if messages[0].role == 'system' %} + {%- set ns.system_message = messages[0].content %} + {%- if tools and documents %} + {%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message + '\n\n' + ns.documents_system_message %} + {%- elif tools %} + {%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message %} + {%- elif documents %} + {%- set ns.system_message = ns.system_message + '\n\n' + ns.documents_system_message %} + {%- endif %} +{%- else %} + {%- if tools and documents %} + {%- set ns.system_message = ns.tools_system_message + '\n\n' + ns.documents_system_message %} + {%- elif tools %} + {%- set ns.system_message = ns.tools_system_message %} + {%- elif documents %} + {%- set ns.system_message = ns.documents_system_message %} + {%- endif %} +{%- endif %} +{%- if ns.system_message %} + {{- '<|start_of_role|>system<|end_of_role|>' + ns.system_message + '<|end_of_text|>\n' }} +{%- endif %} +{%- for message in messages|reverse %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if message.role == 'user' and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.last_query_index = index %} + {% break %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- set content = namespace(val='') %} + {%- if message.content is string %} + {%- set content.val = message.content %} + {%- else %} + {%- if message.content is iterable %} + {%- for entry in message.content %} + {%- if entry.type== 'text' %} + {%- if content.val != '' %} + {%- set content.val = content.val + '\n' %} + {%- endif %} + {%- set content.val = content.val + entry.text %} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- endif %} + {%- if (message.role == 'user') or (message.role == 'system' and not loop.first) %} + {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val + '<|end_of_text|>\n' }} + {%- elif message.role == 'assistant' %} + {%- set thought = '' %} + {%- if message.thought is string %} + {%- set thought = message.thought %} + {%- else %} + {%- if '' in content.val %} + {%- set thought = content.val.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content.val = content.val.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and thought) %} + {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + '\n\n' + thought.strip('\n') + '\n\n\n' + content.val.lstrip('\n') }} + {%- else %} + {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }} + {%- endif %} + {%- else %} + {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content.val) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|end_of_text|>\n' }} + {%- elif message.role == 'tool' %} + {%- if loop.first or (messages[loop.index0 - 1].role != 'tool') %} + {{- '<|start_of_role|>user<|end_of_role|>' }} + {%- endif %} + {{- '\n\n' }} + {{- content.val }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != 'tool') %} + {{- '<|end_of_text|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_of_role|>assistant<|end_of_role|>' }} + {%- if thinking is defined and thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 8781e2b8c1..4bc3848758 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -27,6 +27,7 @@ from functools import partial from typing import List, Optional, Union +import pandas as pd import datasets import deepspeed import torch @@ -701,8 +702,9 @@ def main(args: FlatArguments): configs=args.dataset_config_name, splits=["train"], save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["messages"], + columns_to_keep=["messages","tools","documents"], ) + print(f"####### Loaded datasets {raw_datasets}") elif args.dataset_mixer_list is not None: # mixing datasets via config raw_datasets = get_datasets( @@ -710,19 +712,28 @@ def main(args: FlatArguments): configs=args.dataset_config_name, splits=["train"], save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None, - columns_to_keep=["messages"], + columns_to_keep=["messages","tools","documents"], ) + print(f"####### Loaded datasets {raw_datasets}") else: data_files = {} dataset_args = {} if args.train_file is not None: data_files["train"] = args.train_file - with accelerator.main_process_first(): - raw_datasets = load_dataset( - args.train_file_type, - data_files=data_files, - **dataset_args, - ) + with accelerator.main_process_first(): + try: + raw_datasets = load_dataset( + args.train_file_type, + data_files=data_files, + **dataset_args, + ) + except: + if isinstance(args.train_file, list) and len(args.train_files) > 1 or isinstance(args.train_file, dict): + print("Passing a single file when reading with pandas. Please provide a single jsonl file.") + else: + df = pd.read_json(args.train_file,lines=True,orient='records') + dataset = datasets.Dataset.from_pandas(df) + raw_datasets = dataset.train_test_split(test_size=1) # Load pretrained model and tokenizer if args.config_name: @@ -908,6 +919,13 @@ def main(args: FlatArguments): # also add bos in the chat template if not already there tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template + # check chat template: + if accelerator.is_main_process: + accelerator.print(f"\n **** debug_chat_template_tokenization ****") + debug_chat_template_tokenization(tokenizer,None,None,None) # naive sample + debug_chat_template_tokenization(tokenizer,"Default","Default","Default") + # stop_debugging(accelerator) + if args.use_lora: if args.use_qlora: model = prepare_model_for_kbit_training( @@ -1519,5 +1537,6 @@ def main(args: FlatArguments): parser = ArgumentParserPlus((FlatArguments)) args = parser.parse() if os.environ["RANK"] == "0": - print(f"{args=}") + print(f"\n*** Input args:") + print(json.dumps(vars(args), indent=4)) main(args) diff --git a/open_instruct/utils.py b/open_instruct/utils.py index fc0964c778..89760e31d2 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -23,19 +23,20 @@ import time from dataclasses import dataclass from typing import Any, List, NewType, Optional, Tuple, Union +import pandas as pd import requests from accelerate.logging import get_logger -from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk +from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk, Dataset from datasets.builder import DatasetGenerationError from dateutil import parser from huggingface_hub import HfApi from rich.pretty import pprint from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser -from transformers import PreTrainedTokenizer -from typing import List, Dict, Optional -from accelerate import Accelerator + +from typing import List, Optional + MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -249,7 +250,12 @@ def get_datasets( for split in splits: # if dataset ends with .json or .jsonl, load from file if ds.endswith(".json") or ds.endswith(".jsonl"): - dataset = load_dataset("json", data_files=ds, split=split) + try: + dataset = load_dataset("json", data_files=ds, split=split) + except: + df = pd.read_json(ds,lines=True,orient='records') + dataset = Dataset.from_pandas(df) + print(f'###### Dataset loaded:\n {dataset}') else: try: # Try first if dataset on a Hub repo @@ -309,8 +315,10 @@ def get_datasets( ): dataset = dataset.map(convert_rejection_samples_to_messages, num_proc=10) + # if id not in dataset, create it as ds-{index} if "id" not in dataset.column_names: + print(f"#### 'id' not found in {dataset.column_names} \n. Adding new column") id_col = [f"{ds}_{i}" for i in range(len(dataset))] dataset = dataset.add_column("id", id_col) @@ -321,13 +329,17 @@ def get_datasets( # if add_source_col, add that column if add_source_col: + print(f"####### Adding source column") source_col = [ds] * len(dataset) dataset = dataset.add_column("source", source_col) + print(f"######## Done adding columns to obtain\n {dataset}") # for cols in columns_to_keep, if one is not present, add "None" to the column for col in columns_to_keep: if col not in dataset.column_names: + print(f"####### Adding {col} column which is in {columns_to_keep}") dataset = dataset.add_column(col, [None] * len(dataset)) + print(f"######## Done adding columns in columns_to_keep to obtain\n {dataset}") # add tag to the dataset corresponding to where it was sourced from, for if "train" in split: @@ -389,11 +401,13 @@ def get_datasets( # optional save if save_data_dir: + print(f"####### Saving datasets to {save_data_dir}") for split in raw_datasets: raw_datasets[split].to_json(save_data_dir + f"mixed_ds_{split}.json") if not keep_ids: # remove id column + print(f"Removing id column, keep_ids: {keep_ids}") if len(raw_train_datasets) > 0: if "id" in raw_datasets["train"].column_names: raw_datasets["train"] = raw_datasets["train"].remove_columns("id") @@ -900,162 +914,4 @@ def upload_metadata_to_hf( ) os.remove("tmp.json") - -def _get_default_messages(): - # used for testing granite4 chat template - # messages = [ - # {"role": "user", "content": "Who?"}, - # {"role": "assistant", "content": "LLM"}, - # ] - messages = [ - {"role": "system", "content": "You are a weather assistant that responds with relevant function calls instead of natural language."}, - {"role": "user", "content": "What's the weather like in Bengaluru?"}, - {"role": "assistant", "content": "get_coordinates(city='Bengaluru')"}, - {"role": "system", "content": "Coordinates retrieved successfully. You can now use weather-related functions with latitude and longitude."}, - {"role": "user", "content": "Can you tell me the current temperature there?"}, - {"role": "assistant", "content": "get_current_weather(lat=12.97, lon=77.59)"}, - {"role": "system", "content": "User has requested a multi-day forecast. Switch to forecast mode."}, - {"role": "user", "content": "Actually, I need the 3-day forecast for planning a trip."}, - {"role": "assistant", "content": "get_weather_forecast(lat=12.97, lon=77.59, days=3)"}, - ] - return messages - -def _get_default_tools(): - # used for testing granite4 chat template - tools = [ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather for a specified city.", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "Name of the city" - } - }, - "required": ["city"] - } - } - }, - { - "type": "function", - "function": { - "name": "get_time", - "description": "Get the current time for a specified location.", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "Coordinates of the location" - } - }, - "required": ["location"] - } - } - } - ] - return tools - -def _get_default_RAG_documents(): - # used for testing granite4 chat template - documents = [ - { - "doc_id": 1, - "title": "", - "text": "From the early 12th century, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", - "source": "" - }, - { - "doc_id": 2, - "title": "", - "text": "From long time ago, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", - "source": "" - }, - { - "doc_id": 3, - "title": "", - "text": "From yesterday, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", - "source": "" - } - ] - return documents - - -def debug_chat_template_tokenization( - tokenizer: PreTrainedTokenizer, - messages: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...] - tools: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...] - documents: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...] -) -> None: - """ - Applies chat template to a sample of messages/tools/documents and tokenizes the resulting text. - - Args: - tokenizer (PreTrainedTokenizer): The tokenizer instance. - messages: Either "Default", None, or a list of {"role": ..., "content": ...} dicts. - tools: Either "Default", None, or a list of tool dicts. - documents: Either "Default", None, or a list of document dicts. - Example: - debug_chat_template_tokenization(tokenizer) - debug_chat_template_tokenization(tokenizer,"Default","Default","Default") - - """ - print( - f"\n== Tokenizer info: {len(tokenizer):,} tokens (vocab_size={tokenizer.vocab_size:,}) ==" - f"\n== Special Tokens Map (len={len(tokenizer.special_tokens_map)}):" - ) - - # special_token -> tokenID: - for name, token_str in tokenizer.special_tokens_map.items(): - token_id = tokenizer.convert_tokens_to_ids(token_str) - print(f" {name:>20}: '{token_str}' --> ID: {token_id}") - - # default messages: - if messages in ("Default", None): - messages = _get_default_messages() - print(f"\n== Messages:\n{messages}") - - # Collect optional inputs for applying chat template: - additional_inputs = {} - if tools is not None: - additional_inputs["tools"] = _get_default_tools() if tools == "Default" else tools - print(f"\n== Tools:\n{tools}") - - if documents is not None: - additional_inputs["documents"] = _get_default_RAG_documents() if documents == "Default" else documents - print(f"\n== Documents:\n{documents}") - - text = tokenizer.apply_chat_template(messages, - tokenize=False, - **additional_inputs, - ) - - tokens = tokenizer.encode(text, add_special_tokens=False) - decoded_tokens = [tokenizer.decode([t]) for t in tokens] - zipped = list(zip(tokens, decoded_tokens)) - - - print(f"\n== Final Chat Template Output:\n{text}") - print(f"\n== Tokenization Output (len={len(tokens)}):\n{tokens}\n") - for token_id, token_str in zipped: - print(f"{token_id:6d} -> `{token_str}`") - - -def stop_debugging(accelerator: Accelerator) -> None: - """ - Stops debugging and cleans up distributed resources. - Args: - accelerator (Accelerator): The accelerator instance managing distributed training. - """ - - # ensure all processes wait until the main process finishes - accelerator.wait_for_everyone() - - # clean up distributed resources - accelerator.end_training() - accelerator.free_memory() - sys.exit("== STOP DEBUGGING ==") \ No newline at end of file + \ No newline at end of file diff --git a/open_instruct/utils_granite.py b/open_instruct/utils_granite.py index 026b8ed0e3..81ff9747d4 100644 --- a/open_instruct/utils_granite.py +++ b/open_instruct/utils_granite.py @@ -21,13 +21,10 @@ def _get_default_messages(): messages = [ {"role": "system", "content": "You are a weather assistant that responds with relevant function calls instead of natural language."}, {"role": "user", "content": "What's the weather like in Bengaluru?"}, - {"role": "assistant", "content": "get_coordinates(city='Bengaluru')"}, + {"role": "assistant", "content": " Need to determine coordinates for Bengaluru get_coordinates(city='Bengaluru')"}, {"role": "system", "content": "Coordinates retrieved successfully. You can now use weather-related functions with latitude and longitude."}, - {"role": "user", "content": "Can you tell me the current temperature there?"}, - {"role": "assistant", "content": "get_current_weather(lat=12.97, lon=77.59)"}, - {"role": "system", "content": "User has requested a multi-day forecast. Switch to forecast mode."}, {"role": "user", "content": "Actually, I need the 3-day forecast for planning a trip."}, - {"role": "assistant", "content": "get_weather_forecast(lat=12.97, lon=77.59, days=3)"}, + {"role": "assistant", "content": " get_weather_forecast(lat=12.97, lon=77.59, days=3)", "thought": "User wants forecast. Need to call forecast API"}, ] return messages @@ -77,19 +74,13 @@ def _get_default_RAG_documents(): { "doc_id": 1, "title": "", - "text": "From the early 12th century, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", + "text": "Doc 1", "source": "" }, { "doc_id": 2, "title": "", - "text": "From long time ago, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", - "source": "" - }, - { - "doc_id": 3, - "title": "", - "text": "From yesterday, French builders developed the Gothic style, marked by the use of rib vaults, pointed arches, flying buttresses, and large stained glass windows. It was used mainly in churches and cathedrals, and continued in use until the 16th century in much of Europe. Classic examples of Gothic architecture include Chartres Cathedral and Reims Cathedral in France as well as Salisbury Cathedral in England. Stained glass became a crucial element in the design of churches, which continued to use extensive wall-paintings, now almost all lost.", + "text": "Doc 2", "source": "" } ] @@ -166,7 +157,7 @@ def debug_chat_template_tokenization( print(f"{token_id:6d} -> `{token_str}`") -def stop_debugging(accelerator: Accelerator) -> None: +def stop_debugging(accelerator: Accelerator = None,msg:str=None) -> None: """ Stops debugging and cleans up distributed resources. Args: @@ -174,9 +165,15 @@ def stop_debugging(accelerator: Accelerator) -> None: """ # ensure all processes wait until the main process finishes - accelerator.wait_for_everyone() - # clean up distributed resources - accelerator.end_training() - accelerator.free_memory() + if msg is not None: + if accelerator is not None and accelerator.is_local_main_process: + accelerator.print(f"\n\n** {msg} **\n") + elif accelerator is None: + print(f"\n\n** {msg} **\n") + + if accelerator is not None: + accelerator.wait_for_everyone() + accelerator.end_training() + accelerator.free_memory() sys.exit("== STOP DEBUGGING ==") \ No newline at end of file