diff --git a/.gitignore b/.gitignore
index 423d441d13..5d5a245a9b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,4 +146,4 @@ dmypy.json
.idea/
.vscode
-.DS_Store
\ No newline at end of file
+.DS_Store
diff --git a/open_instruct/chat_templates/ct-granite4-v01.jinja2 b/open_instruct/chat_templates/ct-granite4-v01.jinja2
new file mode 100644
index 0000000000..e653cc5b08
--- /dev/null
+++ b/open_instruct/chat_templates/ct-granite4-v01.jinja2
@@ -0,0 +1,130 @@
+{%- set tools_system_message_prefix = 'You are a helpful assistant with access to the following tools. You may call one or more tools to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>' %}
+{%- set tools_system_message_suffix = '\n</tools>\n\nFor each tool call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.' %}
+{%- set documents_system_message_prefix = 'You are a helpful assistant with access to the following documents. You may use one or more documents to assist with the user query.\n\nYou are given a list of documents within <documents></documents> XML tags:\n<documents>' %}
+{%- set documents_system_message_suffix = '\n</documents>\n\nWrite the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.' %}
+{%- if available_tools is defined and available_tools %}
+ {%- set tools = available_tools %}
+{%- endif %}
+{%- set ns = namespace(tools_system_message=tools_system_message_prefix,
+ documents_system_message=documents_system_message_prefix,
+ system_message='',
+ last_query_index=messages|length - 1) %}
+{%- if tools %}
+ {%- for tool in tools %}
+ {%- set ns.tools_system_message = ns.tools_system_message + '\n' + (tool | tojson) %}
+ {%- endfor %}
+ {%- set ns.tools_system_message = ns.tools_system_message + tools_system_message_suffix %}
+{%- else %}
+ {%- set ns.tools_system_message = '' %}
+{%- endif %}
+{%- if documents %}
+ {%- for document in documents %}
+ {%- set ns.documents_system_message = ns.documents_system_message + '\n' + (document | tojson) %}
+ {%- endfor %}
+ {%- set ns.documents_system_message = ns.documents_system_message + documents_system_message_suffix %}
+{%- else %}
+ {%- set ns.documents_system_message = '' %}
+{%- endif %}
+{%- if messages[0].role == 'system' %}
+ {%- set ns.system_message = messages[0].content %}
+ {%- if tools and documents %}
+ {%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message + '\n\n' + ns.documents_system_message %}
+ {%- elif tools %}
+ {%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message %}
+ {%- elif documents %}
+ {%- set ns.system_message = ns.system_message + '\n\n' + ns.documents_system_message %}
+ {%- endif %}
+{%- else %}
+ {%- if tools and documents %}
+ {%- set ns.system_message = ns.tools_system_message + '\n\n' + ns.documents_system_message %}
+ {%- elif tools %}
+ {%- set ns.system_message = ns.tools_system_message %}
+ {%- elif documents %}
+ {%- set ns.system_message = ns.documents_system_message %}
+ {%- endif %}
+{%- endif %}
+{%- if ns.system_message %}
+ {{- '<|start_of_role|>system<|end_of_role|>' + ns.system_message + '<|end_of_text|>\n' }}
+{%- endif %}
+{%- for message in messages|reverse %}
+ {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if message.role == 'user' and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+ {%- set ns.last_query_index = index %}
+ {% break %}
+ {%- endif %}
+{%- endfor %}
+{%- for message in messages %}
+ {%- set content = namespace(val='') %}
+ {%- if message.content is string %}
+ {%- set content.val = message.content %}
+ {%- else %}
+ {%- if message.content is iterable %}
+ {%- for entry in message.content %}
+ {%- if entry.type== 'text' %}
+ {%- if content.val != '' %}
+ {%- set content.val = content.val + '\n' %}
+ {%- endif %}
+ {%- set content.val = content.val + entry.text %}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {%- endif %}
+ {%- if (message.role == 'user') or (message.role == 'system' and not loop.first) %}
+ {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val + '<|end_of_text|>\n' }}
+ {%- elif message.role == 'assistant' %}
+ {%- set thought = '' %}
+ {%- if message.thought is string %}
+ {%- set thought = message.thought %}
+ {%- else %}
+            {%- if '</think>' in content.val %}
+                {%- set thought = content.val.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+                {%- set content.val = content.val.split('</think>')[-1].lstrip('\n') %}
+ {%- endif %}
+ {%- endif %}
+ {%- if loop.index0 > ns.last_query_index %}
+ {%- if loop.last or (not loop.last and thought) %}
+ {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + '\n\n' + thought.strip('\n') + '\n\n\n' + content.val.lstrip('\n') }}
+ {%- else %}
+ {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }}
+ {%- endif %}
+ {%- else %}
+ {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }}
+ {%- endif %}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if (loop.first and content.val) or (not loop.first) %}
+ {{- '\n' }}
+ {%- endif %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+            {{- '<tool_call>\n{"name": "' }}
+            {{- tool_call.name }}
+            {{- '", "arguments": ' }}
+            {%- if tool_call.arguments is string %}
+                {{- tool_call.arguments }}
+            {%- else %}
+                {{- tool_call.arguments | tojson }}
+            {%- endif %}
+            {{- '}\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|end_of_text|>\n' }}
+ {%- elif message.role == 'tool' %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != 'tool') %}
+ {{- '<|start_of_role|>user<|end_of_role|>' }}
+ {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- content.val }}
+        {{- '\n</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != 'tool') %}
+ {{- '<|end_of_text|>\n' }}
+ {%- endif %}
+ {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+ {{- '<|start_of_role|>assistant<|end_of_role|>' }}
+ {%- if thinking is defined and thinking is false %}
+        {{- '<think>\n\n</think>\n\n' }}
+ {%- endif %}
+{%- endif %}
\ No newline at end of file
diff --git a/open_instruct/chat_templates/ct-granite4.jinja2 b/open_instruct/chat_templates/ct-granite4.jinja2
new file mode 100644
index 0000000000..33f7e5b469
--- /dev/null
+++ b/open_instruct/chat_templates/ct-granite4.jinja2
@@ -0,0 +1,138 @@
+{%- set tools_system_message_prefix = 'You are a helpful assistant with access to the following tools. You may call one or more tools to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>' %}
+{%- set tools_system_message_suffix = '\n</tools>\n\nFor each tool call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.' %}
+{%- set documents_system_message_prefix = 'You are a helpful assistant with access to the following documents. You may use one or more documents to assist with the user query.\n\nYou are given a list of documents within <documents></documents> XML tags:\n<documents>' %}
+{%- set documents_system_message_suffix = '\n</documents>\n\nWrite the response to the user\'s input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data.' %}
+{%- if available_tools is defined and available_tools %}
+ {%- set tools = available_tools %}
+{%- endif %}
+{%- set ns = namespace(tools_system_message=tools_system_message_prefix,
+ documents_system_message=documents_system_message_prefix,
+ system_message='',
+ last_query_index=messages|length - 1) %}
+{%- if tools %}
+ {%- for tool in tools %}
+ {%- set ns.tools_system_message = ns.tools_system_message + '\n' + (tool | tojson) %}
+ {%- endfor %}
+ {%- set ns.tools_system_message = ns.tools_system_message + tools_system_message_suffix %}
+{%- else %}
+ {%- set ns.tools_system_message = '' %}
+{%- endif %}
+{%- if documents %}
+ {%- for document in documents %}
+ {%- set ns.documents_system_message = ns.documents_system_message + '\n' + (document | tojson) %}
+ {%- endfor %}
+ {%- set ns.documents_system_message = ns.documents_system_message + documents_system_message_suffix %}
+{%- else %}
+ {%- set ns.documents_system_message = '' %}
+{%- endif %}
+{%- if messages[0].role == 'system' %}
+ {%- if messages[0].content is string %}
+ {%- set ns.system_message = messages[0].content %}
+ {%- elif messages[0].content is iterable %}
+ {%- for entry in messages[0].content %}
+ {%- if entry.type== 'text' %}
+ {%- if ns.system_message != '' %}
+ {%- set ns.system_message = ns.system_message + '\n' %}
+ {%- endif %}
+ {%- set ns.system_message = ns.system_message + entry.text %}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {%- if tools and documents %}
+ {%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message + '\n\n' + ns.documents_system_message %}
+ {%- elif tools %}
+ {%- set ns.system_message = ns.system_message + '\n\n' + ns.tools_system_message %}
+ {%- elif documents %}
+ {%- set ns.system_message = ns.system_message + '\n\n' + ns.documents_system_message %}
+ {%- endif %}
+{%- else %}
+ {%- if tools and documents %}
+ {%- set ns.system_message = ns.tools_system_message + '\n\n' + ns.documents_system_message %}
+ {%- elif tools %}
+ {%- set ns.system_message = ns.tools_system_message %}
+ {%- elif documents %}
+ {%- set ns.system_message = ns.documents_system_message %}
+ {%- endif %}
+{%- endif %}
+{%- if ns.system_message %}
+ {{- '<|start_of_role|>system<|end_of_role|>' + ns.system_message + '<|end_of_text|>\n' }}
+{%- endif %}
+{%- for message in messages|reverse %}
+ {%- set index = (messages|length - 1) - loop.index0 %}
+ {%- if message.role == 'user' %}
+ {%- set content = namespace(val='') %}
+ {%- if message.content is string %}
+ {%- set content.val = message.content %}
+ {%- else %}
+ {%- if message.content is iterable %}
+ {%- for entry in message.content %}
+ {%- if entry.type== 'text' %}
+ {%- if content.val != '' %}
+ {%- set content.val = content.val + '\n' %}
+ {%- endif %}
+ {%- set content.val = content.val + entry.text %}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {%- endif %}
+        {%- if not(content.val.startswith('<tool_response>') and content.val.endswith('</tool_response>')) %}
+ {%- set ns.last_query_index = index %}
+ {% break %}
+ {%- endif %}
+ {%- endif %}
+{%- endfor %}
+{%- for message in messages %}
+ {%- set content = namespace(val='') %}
+ {%- if message.content is string %}
+ {%- set content.val = message.content %}
+ {%- else %}
+ {%- if message.content is iterable %}
+ {%- for entry in message.content %}
+ {%- if entry.type== 'text' %}
+ {%- if content.val != '' %}
+ {%- set content.val = content.val + '\n' %}
+ {%- endif %}
+ {%- set content.val = content.val + entry.text %}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {%- endif %}
+ {%- if (message.role == 'user') or (message.role == 'system' and not loop.first) %}
+ {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val + '<|end_of_text|>\n' }}
+ {%- elif message.role == 'assistant' %}
+ {{- '<|start_of_role|>' + message.role + '<|end_of_role|>' + content.val }}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if (loop.first and content.val) or (not loop.first) %}
+ {{- '\n' }}
+ {%- endif %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+            {{- '<tool_call>\n{"name": "' }}
+            {{- tool_call.name }}
+            {{- '", "arguments": ' }}
+            {%- if tool_call.arguments is string %}
+                {{- tool_call.arguments }}
+            {%- else %}
+                {{- tool_call.arguments | tojson }}
+            {%- endif %}
+            {{- '}\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|end_of_text|>\n' }}
+ {%- elif message.role == 'tool' %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != 'tool') %}
+ {{- '<|start_of_role|>user<|end_of_role|>' }}
+ {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- content.val }}
+        {{- '\n</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != 'tool') %}
+ {{- '<|end_of_text|>\n' }}
+ {%- endif %}
+ {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+ {{- '<|start_of_role|>assistant<|end_of_role|>' }}
+{%- endif %}
\ No newline at end of file
diff --git a/open_instruct/dataset_processor.py b/open_instruct/dataset_processor.py
index 891984dde8..befee2c4e5 100644
--- a/open_instruct/dataset_processor.py
+++ b/open_instruct/dataset_processor.py
@@ -40,6 +40,7 @@
from tqdm import tqdm
from transformers import PreTrainedTokenizer
+
logging.basicConfig(level=logging.INFO)
@@ -87,6 +88,11 @@
# flake8: noqa
# note we added `{% if loop.last and not add_generation_prompt %}{{ eos_token }}{% endif %}`
# because we want the template to not output eos_token if `add_generation_prompt=True`
+
+CHAT_TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chat_templates")
+with open(os.path.join(CHAT_TEMPLATE_DIR, "ct-granite4.jinja2"), 'r', encoding='utf-8') as fid:
+ granite4_template = fid.read().strip()
+
CHAT_TEMPLATES = {
"simple_concat_with_space": (
"{% for message in messages %}"
@@ -202,7 +208,9 @@
"{%- endif %}"
"{%- endfor %}"
),
+ "granite4": granite4_template
}
+
# flake8: noqa
# Performance tuning. Some rough numbers:
diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 9cc8a005ed..4bc3848758 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -27,6 +27,7 @@
from functools import partial
from typing import List, Optional, Union
+import pandas as pd
import datasets
import deepspeed
import torch
@@ -36,7 +37,7 @@
from accelerate.utils import InitProcessGroupKwargs, set_seed
from datasets import load_dataset
from huggingface_hub import HfApi
-from padding_free_collator import TensorDataCollatorWithFlattening
+from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
@@ -69,6 +70,8 @@
upload_metadata_to_hf,
)
+from open_instruct.utils_granite import debug_chat_template_tokenization, stop_debugging,add_special_chat_tokens
+
logger = get_logger(__name__)
@@ -118,6 +121,12 @@ class FlatArguments:
)
},
)
+ # List of special tokens to be added to tokenizer, eg:
+ add_special_tokens: Optional[List[str]] = field(
+ default=None,
+ metadata={"help": "List of additional special tokens to add to the tokenizer"},
+ )
+
use_flash_attn: bool = field(
default=True,
metadata={"help": "Whether to use flash attention in the model training"},
@@ -178,7 +187,7 @@ class FlatArguments:
)
train_file: Optional[str] = field(
default=None,
- metadata={"help": "The input training data file (a json/jsonl file)."},
+ metadata={"help": "The input training data file (a json/jsonl/parquet file or directory)."},
)
max_train_samples: Optional[int] = field(
default=None,
@@ -451,8 +460,29 @@ def __post_init__(self):
)
else:
if self.train_file is not None:
- extension = self.train_file.split(".")[-1]
- assert extension in ["json", "jsonl", "parquet"], (
+ if os.path.isdir(self.train_file):
+            # assume every entry in the directory is a training data file
+ self.train_file = [
+ os.path.join(self.train_file, x)
+ for x in os.listdir(self.train_file)
+ ]
+ self.train_file_type = [
+ x.split(".")[-1] for x in self.train_file
+ ]
+ self.train_file_type = [
+ x for x in self.train_file_type
+ if x in ["json", "jsonl", "parquet"]
+ ]
+ self.train_file_type = list(set(self.train_file_type)) # unique
+ # assume the directory cannot mix types
+ self.train_file_type = (
+ None if len(self.train_file_type) == 0 else
+ self.train_file_type[0]
+ )
+ else:
+ self.train_file_type = self.train_file.split(".")[-1]
+
+ assert self.train_file_type in ["json", "jsonl", "parquet"], (
"`train_file` should be a json or a jsonl or parquet file."
)
if (
@@ -508,6 +538,12 @@ def encode_sft_example(example, tokenizer, max_seq_length):
We use the `apply_chat_template` function from the tokenizer to tokenize the messages and prepare the input and label tensors.
"""
messages = example["messages"]
+
+ additional_inputs = {}
+ for k in ["tools", "documents"]:
+ if k in example:
+ additional_inputs[k] = example[k]
+
if len(messages) == 0:
raise ValueError("messages field is empty.")
input_ids = tokenizer.apply_chat_template(
@@ -518,6 +554,7 @@ def encode_sft_example(example, tokenizer, max_seq_length):
truncation=True,
max_length=max_seq_length,
add_generation_prompt=False,
+ **additional_inputs,
)
labels = input_ids.clone()
# mask the non-assistant part for avoiding loss
@@ -537,6 +574,7 @@ def encode_sft_example(example, tokenizer, max_seq_length):
truncation=True,
max_length=max_seq_length,
add_generation_prompt=False,
+ **additional_inputs,
).shape[1]
# next, we calculate the end index of this non-assistant message
if (
@@ -554,6 +592,7 @@ def encode_sft_example(example, tokenizer, max_seq_length):
truncation=True,
max_length=max_seq_length,
add_generation_prompt=True,
+ **additional_inputs,
).shape[1]
else:
# for the last message or the message that doesn't follow with an assistant message,
@@ -566,6 +605,7 @@ def encode_sft_example(example, tokenizer, max_seq_length):
truncation=True,
max_length=max_seq_length,
add_generation_prompt=False,
+ **additional_inputs,
).shape[1]
# set the label to -100 for the non-assistant part
labels[:, message_start_idx:message_end_idx] = -100
@@ -662,8 +702,9 @@ def main(args: FlatArguments):
configs=args.dataset_config_name,
splits=["train"],
save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None,
- columns_to_keep=["messages"],
+ columns_to_keep=["messages","tools","documents"],
)
+ print(f"####### Loaded datasets {raw_datasets}")
elif args.dataset_mixer_list is not None:
# mixing datasets via config
raw_datasets = get_datasets(
@@ -671,22 +712,28 @@ def main(args: FlatArguments):
configs=args.dataset_config_name,
splits=["train"],
save_data_dir=args.dataset_mix_dir if accelerator.is_main_process else None,
- columns_to_keep=["messages"],
+ columns_to_keep=["messages","tools","documents"],
)
+ print(f"####### Loaded datasets {raw_datasets}")
else:
data_files = {}
dataset_args = {}
if args.train_file is not None:
data_files["train"] = args.train_file
- data_type = "json"
- if args.train_file.endswith('.parquet'):
- data_type = "parquet"
- with accelerator.main_process_first():
- raw_datasets = load_dataset(
- data_type,
- data_files=data_files,
- **dataset_args,
- )
+ with accelerator.main_process_first():
+ try:
+ raw_datasets = load_dataset(
+ args.train_file_type,
+ data_files=data_files,
+ **dataset_args,
+ )
+                except Exception:
+                    if (isinstance(args.train_file, list) and len(args.train_file) > 1) or isinstance(args.train_file, dict):
+ print("Passing a single file when reading with pandas. Please provide a single jsonl file.")
+ else:
+ df = pd.read_json(args.train_file,lines=True,orient='records')
+ dataset = datasets.Dataset.from_pandas(df)
+ raw_datasets = dataset.train_test_split(test_size=1)
# Load pretrained model and tokenizer
if args.config_name:
@@ -826,6 +873,10 @@ def main(args: FlatArguments):
assert num_added_tokens == 1, (
"We detected no padding token but add_special_tokens did not add one."
)
+
+ # add special tokens if they are provided:
+ if args.add_special_tokens is not None:
+ tokenizer = add_special_chat_tokens(tokenizer,args.add_special_tokens)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
@@ -868,6 +919,13 @@ def main(args: FlatArguments):
# also add bos in the chat template if not already there
tokenizer.chat_template = "{{ bos_token }}" + tokenizer.chat_template
+ # check chat template:
+ if accelerator.is_main_process:
+ accelerator.print(f"\n **** debug_chat_template_tokenization ****")
+ debug_chat_template_tokenization(tokenizer,None,None,None) # naive sample
+ debug_chat_template_tokenization(tokenizer,"Default","Default","Default")
+ # stop_debugging(accelerator)
+
if args.use_lora:
if args.use_qlora:
model = prepare_model_for_kbit_training(
@@ -928,9 +986,10 @@ def main(args: FlatArguments):
)
# Log a few random samples from the training set:
- for index in random.sample(range(len(train_dataset)), 3):
- logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
-
+ if accelerator.is_main_process:
+ for index in random.sample(range(len(train_dataset)), 3):
+ logger.info(f"\nSample {index:,} of the training set: {train_dataset[index]}.")
+
# DataLoaders creation:
if args.padding_free:
accelerator.print("Using padding-free collation")
@@ -1039,9 +1098,10 @@ def main(args: FlatArguments):
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
- accelerator.print(f"{model=}")
- accelerator.print(f"{accelerator.state.fsdp_plugin=}")
- accelerator.print(f"{args=}")
+ if accelerator.is_main_process:
+ accelerator.print(f"\n== {model=}")
+ accelerator.print(f"\n== {accelerator.state.fsdp_plugin=}")
+ accelerator.print(f"\n== {args=}")
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
@@ -1477,5 +1537,6 @@ def main(args: FlatArguments):
parser = ArgumentParserPlus((FlatArguments))
args = parser.parse()
if os.environ["RANK"] == "0":
- print(f"{args=}")
+ print(f"\n*** Input args:")
+ print(json.dumps(vars(args), indent=4))
main(args)
diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py
index de15d55b60..f283511767 100644
--- a/open_instruct/model_utils.py
+++ b/open_instruct/model_utils.py
@@ -379,6 +379,7 @@ def save_with_accelerate(
output_dir: str,
use_lora: bool = False,
model_attribute_to_save: Optional[str] = None,
+ safe_serialization: bool = True
) -> None:
"""`model_attribute_to_save` is for used to save PPO's policy instead of the full model"""
# set the generation config to an empty setting to be safe.
@@ -420,7 +421,7 @@ def save_with_accelerate(
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=state_dict,
- safe_serialization=False,
+ safe_serialization=safe_serialization,
)
if accelerator.is_main_process:
diff --git a/open_instruct/utils.py b/open_instruct/utils.py
index f4f1f9e86d..89760e31d2 100644
--- a/open_instruct/utils.py
+++ b/open_instruct/utils.py
@@ -23,16 +23,21 @@
import time
from dataclasses import dataclass
from typing import Any, List, NewType, Optional, Tuple, Union
+import pandas as pd
import requests
from accelerate.logging import get_logger
-from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk, Dataset
from datasets.builder import DatasetGenerationError
from dateutil import parser
from huggingface_hub import HfApi
from rich.pretty import pprint
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
+
+from typing import List, Optional
+
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -245,7 +250,12 @@ def get_datasets(
for split in splits:
# if dataset ends with .json or .jsonl, load from file
if ds.endswith(".json") or ds.endswith(".jsonl"):
- dataset = load_dataset("json", data_files=ds, split=split)
+ try:
+ dataset = load_dataset("json", data_files=ds, split=split)
+        except Exception:
+ df = pd.read_json(ds,lines=True,orient='records')
+ dataset = Dataset.from_pandas(df)
+ print(f'###### Dataset loaded:\n {dataset}')
else:
try:
# Try first if dataset on a Hub repo
@@ -305,8 +315,10 @@ def get_datasets(
):
dataset = dataset.map(convert_rejection_samples_to_messages, num_proc=10)
+
# if id not in dataset, create it as ds-{index}
if "id" not in dataset.column_names:
+ print(f"#### 'id' not found in {dataset.column_names} \n. Adding new column")
id_col = [f"{ds}_{i}" for i in range(len(dataset))]
dataset = dataset.add_column("id", id_col)
@@ -317,13 +329,17 @@ def get_datasets(
# if add_source_col, add that column
if add_source_col:
+ print(f"####### Adding source column")
source_col = [ds] * len(dataset)
dataset = dataset.add_column("source", source_col)
+ print(f"######## Done adding columns to obtain\n {dataset}")
# for cols in columns_to_keep, if one is not present, add "None" to the column
for col in columns_to_keep:
if col not in dataset.column_names:
+ print(f"####### Adding {col} column which is in {columns_to_keep}")
dataset = dataset.add_column(col, [None] * len(dataset))
+ print(f"######## Done adding columns in columns_to_keep to obtain\n {dataset}")
# add tag to the dataset corresponding to where it was sourced from, for
if "train" in split:
@@ -385,11 +401,13 @@ def get_datasets(
# optional save
if save_data_dir:
+ print(f"####### Saving datasets to {save_data_dir}")
for split in raw_datasets:
raw_datasets[split].to_json(save_data_dir + f"mixed_ds_{split}.json")
if not keep_ids:
# remove id column
+ print(f"Removing id column, keep_ids: {keep_ids}")
if len(raw_train_datasets) > 0:
if "id" in raw_datasets["train"].column_names:
raw_datasets["train"] = raw_datasets["train"].remove_columns("id")
@@ -895,3 +913,5 @@ def upload_metadata_to_hf(
repo_type="dataset",
)
os.remove("tmp.json")
+
+
\ No newline at end of file
diff --git a/open_instruct/utils_granite.py b/open_instruct/utils_granite.py
new file mode 100644
index 0000000000..81ff9747d4
--- /dev/null
+++ b/open_instruct/utils_granite.py
@@ -0,0 +1,179 @@
+from transformers import PreTrainedTokenizer
+from typing import List, Dict, Optional, Union
+from accelerate import Accelerator
+import sys
+
+
+def _get_simple_messages():
+ # used for testing granite4 chat template
+ messages = [
+ {"role": "user", "content": "Who?"},
+ {"role": "assistant", "content": "LLM"},
+ ]
+ return messages
+
+def _get_default_messages():
+ # used for testing granite4 chat template
+ # messages = [
+ # {"role": "user", "content": "Who?"},
+ # {"role": "assistant", "content": "LLM"},
+ # ]
+ messages = [
+ {"role": "system", "content": "You are a weather assistant that responds with relevant function calls instead of natural language."},
+ {"role": "user", "content": "What's the weather like in Bengaluru?"},
+ {"role": "assistant", "content": " Need to determine coordinates for Bengaluru get_coordinates(city='Bengaluru')"},
+ {"role": "system", "content": "Coordinates retrieved successfully. You can now use weather-related functions with latitude and longitude."},
+ {"role": "user", "content": "Actually, I need the 3-day forecast for planning a trip."},
+ {"role": "assistant", "content": " get_weather_forecast(lat=12.97, lon=77.59, days=3)", "thought": "User wants forecast. Need to call forecast API"},
+ ]
+ return messages
+
+def _get_default_tools():
+ # used for testing granite4 chat template
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather for a specified city.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "city": {
+ "type": "string",
+ "description": "Name of the city"
+ }
+ },
+ "required": ["city"]
+ }
+ }
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_time",
+ "description": "Get the current time for a specified location.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "Coordinates of the location"
+ }
+ },
+ "required": ["location"]
+ }
+ }
+ }
+ ]
+ return tools
+
+def _get_default_RAG_documents():
+ # used for testing granite4 chat template
+ documents = [
+ {
+ "doc_id": 1,
+ "title": "",
+ "text": "Doc 1",
+ "source": ""
+ },
+ {
+ "doc_id": 2,
+ "title": "",
+ "text": "Doc 2",
+ "source": ""
+ }
+ ]
+ return documents
+
+def add_special_chat_tokens(tokenizer, add_special_tokens:list):
+ existing_special_tokens = tokenizer.special_tokens_map.get("additional_special_tokens", [])
+ new_special_tokens = [t for t in add_special_tokens if t not in existing_special_tokens]
+ if new_special_tokens:
+ all_special_tokens = existing_special_tokens + new_special_tokens
+ tokenizer.add_special_tokens({"additional_special_tokens": all_special_tokens})
+
+ return tokenizer
+
+
+
+def debug_chat_template_tokenization(
+ tokenizer: PreTrainedTokenizer,
+ messages: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...]
+ tools: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...]
+ documents: Union[None, str, List[Dict[str, str]]] = None, # or "Default" or a list [...]
+) -> None:
+ """
+ Applies chat template to a sample of messages/tools/documents and tokenizes the resulting text.
+
+ Args:
+ tokenizer (PreTrainedTokenizer): The tokenizer instance.
+ messages: Either "Default", None, or a list of {"role": ..., "content": ...} dicts.
+ tools: Either "Default", None, or a list of tool dicts.
+ documents: Either "Default", None, or a list of document dicts.
+ Example:
+ debug_chat_template_tokenization(tokenizer)
+ debug_chat_template_tokenization(tokenizer,"Default","Default","Default")
+
+ """
+ print(
+ f"\n== Tokenizer info: {len(tokenizer):,} tokens (vocab_size={tokenizer.vocab_size:,}) =="
+ f"\n== Special Tokens Map (len={len(tokenizer.special_tokens_map)}):"
+ )
+
+ # special_token -> tokenID:
+ for name, token_str in tokenizer.special_tokens_map.items():
+ token_id = tokenizer.convert_tokens_to_ids(token_str)
+ print(f" {name:>20}: '{token_str}' --> ID: {token_id}")
+
+ # default messages:
+ if messages in ("Default", None):
+ messages = _get_simple_messages() if messages is None else _get_default_messages()
+ print(f"\n== Messages:\n{messages}")
+
+ # Collect optional inputs for applying chat template:
+ additional_inputs = {}
+ if tools is not None:
+ additional_inputs["tools"] = _get_default_tools() if tools == "Default" else tools
+        print(f"\n== Tools:\n{additional_inputs['tools']}")
+
+ if documents is not None:
+ additional_inputs["documents"] = _get_default_RAG_documents() if documents == "Default" else documents
+        print(f"\n== Documents:\n{additional_inputs['documents']}")
+
+ text = tokenizer.apply_chat_template(messages,
+ tokenize=False,
+ **additional_inputs,
+ )
+
+ tokens = tokenizer.encode(text, add_special_tokens=False)
+ decoded_tokens = [tokenizer.decode([t]) for t in tokens]
+ zipped = list(zip(tokens, decoded_tokens))
+
+
+ print(f"\n== Chat Template Output:\n{text}")
+ print(f"\n== Tokenization Output (len={len(tokens)}):\n{tokens}\n")
+ for token_id, token_str in zipped:
+ print(f"{token_id:6d} -> `{token_str}`")
+
+
+def stop_debugging(accelerator: Accelerator = None,msg:str=None) -> None:
+ """
+ Stops debugging and cleans up distributed resources.
+ Args:
+ accelerator (Accelerator): The accelerator instance managing distributed training.
+ """
+
+ # ensure all processes wait until the main process finishes
+
+ if msg is not None:
+ if accelerator is not None and accelerator.is_local_main_process:
+ accelerator.print(f"\n\n** {msg} **\n")
+ elif accelerator is None:
+ print(f"\n\n** {msg} **\n")
+
+ if accelerator is not None:
+ accelerator.wait_for_everyone()
+ accelerator.end_training()
+ accelerator.free_memory()
+ sys.exit("== STOP DEBUGGING ==")
\ No newline at end of file