Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 63 additions & 8 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AutoTokenizer,
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForSeq2Seq,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
Expand Down Expand Up @@ -169,6 +170,12 @@ def __init__(
"You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument."
)

# TODO: think about this error handling and if we want to enforce seq2seq collator
if not packing and formatting_func is None and dataset_text_field is None and data_collator is not None and not isinstance(data_collator, DataCollatorForSeq2Seq):
raise ValueError(
"If no formatting_func / dataset_text_field provided, the data_collator should be a `DataCollatorForSeq2Seq` object"
)

if is_peft_available() and peft_config is not None:
if not isinstance(peft_config, PeftConfig):
raise ValueError(
Expand Down Expand Up @@ -244,15 +251,15 @@ def make_inputs_require_grad(module, input, output):
# Check whether the dataset is in a supported ChatML or instruction format;
# if it is not, formatting_func stays None.
formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)


requires_input_output_keys = False
if not packing:
if dataset_text_field is None and formatting_func is None:
raise ValueError(
"You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
)

requires_input_output_keys = (dataset_text_field is None and formatting_func is None)
if data_collator is None:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Fall back to the appropriate collator type based on the input_output_keys
data_collator = (DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
if requires_input_output_keys
else DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False))

# Pre-process the datasets only once per node. The remaining processes will use the cache.
with PartialState().local_main_process_first():
Expand All @@ -269,6 +276,7 @@ def make_inputs_require_grad(module, input, output):
num_of_sequences,
chars_per_token,
remove_unused_columns=args.remove_unused_columns if args is not None else True,
requires_input_output_keys=requires_input_output_keys,
**dataset_kwargs,
)
if eval_dataset is not None:
Expand Down Expand Up @@ -365,6 +373,7 @@ def _prepare_dataset(
num_of_sequences,
chars_per_token,
remove_unused_columns=True,
requires_input_output_keys=False,
append_concat_token=True,
add_special_tokens=True,
):
Expand All @@ -384,6 +393,7 @@ def _prepare_dataset(
formatting_func,
add_special_tokens,
remove_unused_columns,
requires_input_output_keys,
)

else:
Expand All @@ -408,10 +418,43 @@ def _prepare_non_packed_dataloader(
formatting_func=None,
add_special_tokens=True,
remove_unused_columns=True,
requires_input_output_keys=False,
):
use_formatting_func = formatting_func is not None and dataset_text_field is None
self._dataset_sanity_checked = False

# Inspired from https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L266
def tokenize_input_output(element):
    """Tokenize a batch of ``input``/``output`` pairs and mask the prompt in the labels.

    ``tokenizer``, ``max_seq_length`` and ``add_special_tokens`` are free
    variables taken from the surrounding scope.

    Args:
        element: A batch mapping with an ``"input"`` and an ``"output"`` key,
            each holding a list of strings of the same length.

    Returns:
        A dict with ``input_ids``, ``labels`` and ``attention_mask``. ``labels``
        equals ``input_ids`` except that the leading positions covered by the
        tokenized prompt are replaced with ``-100`` so they are excluded from
        the loss.
    """
    # Read the EOS text without mutating the tokenizer. Clearing
    # `tokenizer.eos_token` here would persist across calls, so the next batch
    # would try to concatenate `None` to a string, and the shared tokenizer
    # would be left without an EOS token. Fall back to "" when no EOS is defined.
    # NOTE(review): tokenizers configured to append EOS on their own may now
    # emit it twice - confirm against the tokenizers this path must support.
    eos_token = (tokenizer.eos_token or "") if add_special_tokens else ""

    whitespace = (" ", "\n", "\t")
    new_source = []
    for input_element, output_element in zip(element["input"], element["output"]):
        # Insert a single space only when neither side already provides whitespace.
        needs_space = not input_element.endswith(whitespace) and not output_element.startswith(whitespace)
        separator = " " if needs_space else ""
        new_source.append(input_element + separator + output_element + eos_token)

    tokenized_example = tokenizer(
        new_source,
        max_length=max_seq_length,
        truncation=True,
        padding=False,
        add_special_tokens=add_special_tokens,
    )
    input_ids = tokenized_example.input_ids

    # Tokenize the prompt alone to learn how many leading tokens to mask.
    tokenized_prompt = tokenizer(
        element["input"],
        max_length=max_seq_length,
        truncation=True,
        add_special_tokens=add_special_tokens,
    )

    new_labels = []
    for prompt_ids, example_ids in zip(tokenized_prompt.input_ids, input_ids):
        # Cap the masked prefix so labels never grow longer than input_ids.
        prompt_length = min(len(prompt_ids), len(example_ids))
        new_labels.append([-100] * prompt_length + example_ids[prompt_length:])

    return {
        "input_ids": input_ids,
        "labels": new_labels,
        "attention_mask": tokenized_example.attention_mask,
    }

# Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
def tokenize(element):
outputs = tokenizer(
Expand Down Expand Up @@ -444,8 +487,20 @@ def tokenize(element):
f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
)

if requires_input_output_keys:
if "input" in dataset.column_names and "output" in dataset.column_names:
# TODO: if we execute this input path, it is expected that we are using a seq2seq
# collator. If that is the case, the tokenizer should have a pad_token; this is set
# to eos automatically if it's unset and no tokenizer is provided, but we should
# properly handle the case where a tokenizer with no padding token is given.
tokenize_func = tokenize_input_output
else:
raise KeyError("Missing input / output keys")
else:
tokenize_func = tokenize

tokenized_dataset = dataset.map(
tokenize,
tokenize_func,
batched=True,
remove_columns=dataset.column_names if remove_unused_columns else None,
num_proc=self.dataset_num_proc,
Expand Down