@@ -3,7 +3,7 @@
 # Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
 import os
 import logging
-from typing import Union, Optional
+from typing import Union, Optional, Dict
 
 import torch
 import deepspeed
@@ -30,6 +30,7 @@
 from lmflow.utils.constants import (
     LMFLOW_LORA_TARGET_MODULES_MAPPING
 )
+from lmflow.args import ModelArguments
 
 
 logger = logging.getLogger(__name__)
@@ -51,11 +52,12 @@
 class HFModelMixin(BaseModel):
     def __init__(
         self,
-        model_args,
+        model_args: ModelArguments,
         do_train: bool,
         ds_config=None,
         device: Optional[str] = "gpu",
         use_accelerator: bool = False,
+        hf_auto_model_additional_args: Optional[Dict] = None,
         *args,
         **kwargs
     ):
@@ -88,7 +90,7 @@ def __init__(
         self.model_args = model_args
         self.tokenizer = self.__prepare_tokenizer(model_args)
         self.torch_dtype = self.__prepare_dtype(model_args)
-        self.hf_model_config = self.__prepare_model_config(model_args)
+        self.hf_model_config = self.__prepare_model_config(model_args, hf_auto_model_additional_args)
         self.quant_config = self.__prepare_quant_config(model_args)
         self.peft_config = self.__prepare_peft_config(model_args)
 
@@ -106,11 +108,13 @@ def __init__(
             self.tokenizer.eos_token_id = self.backend_model.config.eos_token_id
         if self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if self.backend_model.config.pad_token_id is None:
+            self.backend_model.config.pad_token_id = self.tokenizer.pad_token_id
 
 
     def __prepare_tokenizer(
         self,
-        model_args
+        model_args: ModelArguments,
     ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
         tokenizer_kwargs = {
             "cache_dir": model_args.cache_dir,
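The two added lines close a common gap: decoder-only checkpoints such as GPT-2 define `eos_token_id` but leave the pad token unset on both the tokenizer and the model config. A minimal sketch of the same fallback chain outside LMFlow (model choice and variable names are illustrative only):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# GPT-2 ships with eos_token_id set but no pad token anywhere.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Propagate the choice to the model config so that code reading
# model.config.pad_token_id (e.g. generation, loss masking) agrees
# with the tokenizer instead of seeing None.
if model.config.pad_token_id is None:
    model.config.pad_token_id = tokenizer.pad_token_id
```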
@@ -119,6 +123,8 @@ def __prepare_tokenizer(
             "use_auth_token": True if model_args.use_auth_token else None,
             "trust_remote_code": model_args.trust_remote_code,
         }
+        if model_args.padding_side != 'auto':
+            tokenizer_kwargs["padding_side"] = model_args.padding_side
 
         try:
             if model_args.tokenizer_name:
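For reference, `padding_side` is a standard tokenizer init kwarg, so forwarding it through `tokenizer_kwargs` overrides the default stored with the checkpoint; the `'auto'` sentinel (an LMFlow convention, per this diff) keeps that default. A hedged sketch of the pass-through:

```python
from transformers import AutoTokenizer

padding_side = "left"  # hypothetical CLI value; 'auto' would skip the override

tokenizer_kwargs = {}
if padding_side != "auto":
    tokenizer_kwargs["padding_side"] = padding_side

# The kwarg overrides the checkpoint default (which is "right" for gpt2).
tokenizer = AutoTokenizer.from_pretrained("gpt2", **tokenizer_kwargs)
assert tokenizer.padding_side == "left"
```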
@@ -163,7 +169,7 @@ def __prepare_tokenizer(
 
     def __prepare_dtype(
         self,
-        model_args
+        model_args: ModelArguments,
     ) -> torch.dtype:
         if model_args.arch_type == 'text_regression':
             if model_args.torch_dtype in ["auto", None, "bf16", "bfloat16"]:
@@ -189,8 +195,23 @@ def __prepare_dtype(
 
     def __prepare_model_config(
         self,
-        model_args
+        model_args: ModelArguments,
+        hf_auto_model_additional_args: Optional[Dict] = None,
     ):
+        """Prepare the model configuration for the HF auto classes.
+        Parameters
+        ----------
+        model_args : ModelArguments
+            LMFlow model arguments.
+        hf_auto_model_additional_args : Optional[Dict], optional
+            Special configurations such as `num_labels` in `AutoModelForSequenceClassification`
+            (commonly used in reward modeling) are not preset in `__prepare_model_config`,
+            so they should be passed in through `hf_auto_model_additional_args`.
+        Returns
+        -------
+        config : ModelConfig
+            The HF model config.
+        """
         config_kwargs = {
             "torch_dtype": self.torch_dtype,
             "attn_implementation": "flash_attention_2" if model_args.use_flash_attention else None,
@@ -200,6 +221,9 @@ def __prepare_model_config(
             "trust_remote_code": model_args.trust_remote_code,
             "from_tf": bool(".ckpt" in model_args.model_name_or_path),
         }
+        if hf_auto_model_additional_args is not None:
+            config_kwargs.update(hf_auto_model_additional_args)
+
         if model_args.config_name:
             config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
         elif model_args.model_name_or_path:
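To see how the new parameter reaches the config, here is a hedged sketch of the merge-then-build flow for a reward-modeling setup; `num_labels=1` and the base checkpoint are assumptions for illustration, not values fixed by this diff:

```python
from transformers import AutoConfig

# Extra auto-model kwargs, e.g. a single-score head for reward modeling.
hf_auto_model_additional_args = {"num_labels": 1}

config_kwargs = {"trust_remote_code": False}
if hf_auto_model_additional_args is not None:
    config_kwargs.update(hf_auto_model_additional_args)

# AutoConfig forwards the extra kwargs onto the config object itself.
config = AutoConfig.from_pretrained("gpt2", **config_kwargs)
assert config.num_labels == 1
```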
@@ -217,7 +241,7 @@ def __prepare_model_config(
 
     def __prepare_quant_config(
         self,
-        model_args
+        model_args: ModelArguments,
     ):
         quant_config = None
         if model_args.use_qlora:
@@ -236,7 +260,7 @@ def __prepare_quant_config(
 
     def __prepare_peft_config(
         self,
-        model_args
+        model_args: ModelArguments,
     ):
         peft_config = None
         if model_args.use_lora:
@@ -267,7 +291,7 @@ def __prepare_peft_config(
 
     def __model_module_inject(
         self,
-        model_args
+        model_args: ModelArguments,
     ) -> None:
         """Override some model modules with custom implementations.
 
@@ -286,8 +310,8 @@ def __model_module_inject(
 
     def __prepare_model_for_training(
         self,
-        model_args,
-        hf_auto_model: HF_AUTOMODEL_TYPE
+        model_args: ModelArguments,
+        hf_auto_model: HF_AUTOMODEL_TYPE,
     ):
         # TODO: change to accelerate
         logger.info("Preparing model for training")
@@ -326,7 +350,7 @@ def __prepare_model_for_training(
 
     def __prepare_model_for_inference(
         self,
-        model_args,
+        model_args: ModelArguments,
         hf_auto_model: HF_AUTOMODEL_TYPE,
         use_accelerator,
         ds_config