@@ -18,6 +18,7 @@
 import os
 from typing import Literal, Optional, Union
 
+import numpy as np
 import torch
 from peft import PeftModel
 from transformers import (
@@ -40,6 +41,7 @@
 from lmflow.utils.deprecated import deprecated_args
 from lmflow.utils.envs import is_accelerate_env
 from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available
+from lmflow.utils.protocol import DataProto
 
 logger = logging.getLogger(__name__)
 
@@ -286,7 +288,7 @@ def decode(self, input, **kwargs) -> Union[str, list[str]]:
     )
     def inference(
         self,
-        inputs: Union[str, list[str], torch.Tensor],
+        inputs: Union[str, list[str], torch.Tensor, DataProto],
         sampling_params: Optional[Union[dict, "SamplingParams"]] = None,
         return_logprob: bool = False,
         release_gpu: bool = False,
@@ -296,16 +298,17 @@ def inference(
         enable_deterministic_inference: bool = False,
         attention_backend: Optional[str] = None,
         **kwargs,
-    ):
+    ) -> Union[list[VLLMInferenceResultWithInput], DataProto]:
         """
         Perform the generation process of the model.
 
         Parameters
         ------------
-        inputs : Union[str, list[str], torch.Tensor]
+        inputs : Union[str, list[str], torch.Tensor, DataProto]
             The sequence used as a prompt for the generation or as model inputs to the model.
-            When the inference engine is "vllm" or "sglang", this should be a string or a list of strings.
+            When the inference engine is "vllm", this should be a string or a list of strings.
             When the inference engine is "huggingface", this should be a tensor.
+            When the inference engine is "sglang", this should be a DataProto.
         sampling_params : Optional[Union[dict, "SamplingParams"]], optional
             The sampling parameters to use, by default None.
         return_logprob : bool, optional
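
As a rough orientation for the new signature, here is a minimal usage sketch of the sglang path, assuming a loaded `model` object and only the DataProto fields visible in this diff; nothing here is verified API beyond what the hunks show:

# Hedged sketch: calling the new inference() signature with the sglang engine.
# `model` is an assumed HFDecoderModel-like object; the DataProto fields mirror
# the usage in this diff and are assumptions, not documented API.
import numpy as np

from lmflow.utils.protocol import DataProto

# Prompts live in non_tensor_batch["inputs"]; sampling parameters in meta_info.
prompts = np.array(["What is RLHF?", "Explain KV caching."])
batch = DataProto.from_single_dict(
    data={"inputs": prompts},
    meta_info={"sampling_params": {"temperature": 0.7, "max_new_tokens": 128}},
)

# vllm still takes str/list[str]; huggingface takes a tensor; sglang now takes
# the DataProto and returns it with non_tensor_batch["outputs"] filled in.
result = model.inference(batch, return_logprob=False)
print(result.non_tensor_batch["outputs"])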
@@ -345,7 +348,6 @@ def inference(
         elif inference_engine == "sglang":
             res = self.__sglang_inference(
                 inputs=inputs,
-                sampling_params=sampling_params,
                 return_logprob=return_logprob,
             )
         else:
@@ -439,21 +441,18 @@ def __vllm_inference(
 
     def __sglang_inference(
         self,
-        inputs: list[str],
-        sampling_params: Optional[dict] = None,
+        inputs: DataProto,
         return_logprob: bool = False,
-    ):
+    ) -> DataProto:
         """Perform the SGLang inference process of the model."""
         sglang_outputs = self.backend_model_for_inference.generate(
-            prompt=inputs,
-            sampling_params=sampling_params,
+            prompt=inputs.non_tensor_batch["inputs"].tolist(),  # TODO: use tensors instead of strings later
+            sampling_params=inputs.meta_info["sampling_params"],
             return_logprob=return_logprob,
         )
-        # TODO: unified lmflow sample format
-        for idx, output in enumerate(sglang_outputs):
-            output["input"] = inputs[idx]
-            output["output"] = output.pop("text")
-        return sglang_outputs
+        inputs.non_tensor_batch["outputs"] = [output["text"] for output in sglang_outputs]
+        # TODO: padding for batching the output ids; generation details
+        return inputs
 
     @deprecated_args(
         use_vllm={
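
To make the DataProto contract this method relies on explicit, here is an illustrative stand-in that supports exactly the operations used in the diff (from_single_dict, non_tensor_batch, meta_info, repeat). It is a sketch for orientation only, not the real lmflow.utils.protocol.DataProto:

# Illustrative stand-in for the DataProto surface used by __sglang_inference
# and prepare_inputs_for_inference. The real class may differ.
from dataclasses import dataclass, field

import numpy as np


@dataclass
class DataProtoSketch:
    non_tensor_batch: dict = field(default_factory=dict)  # str -> np.ndarray of objects
    meta_info: dict = field(default_factory=dict)  # batch-level metadata, e.g. sampling_params

    @classmethod
    def from_single_dict(cls, data: dict, meta_info: dict = None) -> "DataProtoSketch":
        # One flat dict in, one batch out; mirrors how the diff builds inputs.
        return cls(non_tensor_batch=dict(data), meta_info=meta_info or {})

    def repeat(self, n: int) -> "DataProtoSketch":
        # Repeat every row n times so each request maps to exactly one output,
        # which is why the diff pins "n": 1 in the per-request sampling_params.
        repeated = {k: np.repeat(v, n) for k, v in self.non_tensor_batch.items()}
        return DataProtoSketch(non_tensor_batch=repeated, meta_info=dict(self.meta_info))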
@@ -471,7 +470,8 @@ def prepare_inputs_for_inference(
         apply_chat_template: bool = True,
         inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
         enable_distributed_inference: bool = False,
-    ) -> Union[list[str], "ray.data.Dataset"]:
+        sampling_params: Optional[dict] = None,
+    ) -> Union[list[str], "ray.data.Dataset", DataProto]:
         if dataset.get_type() == "text_only":
             if apply_chat_template:
                 dataset = dataset.map(
@@ -551,9 +551,20 @@ def preprocess_conversation(sample):
             inference_inputs
         )  # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])}
 
-        if inference_engine == "sglang" and self.tokenizer.bos_token:
-            # in consistent with sglang bench_serving.py demo
-            inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]
+        if inference_engine == "sglang":
+            if self.tokenizer.bos_token:
+                # consistent with the sglang bench_serving.py demo
+                inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]
+
+            # currently DataProto is only tested with sglang inference
+            inference_inputs = np.array(inference_inputs)
+            inference_inputs = DataProto.from_single_dict(
+                data={"inputs": inference_inputs},
+                meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]},
+            )
+
+            # handle n > 1 here, since we don't want a one-to-many mapping; later this will apply to all inference engines
+            inference_inputs = inference_inputs.repeat(sampling_params["n"])
 
         return inference_inputs
 
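
Putting the two hunks together, the n > 1 flow works roughly as sketched below; `model` and `dataset` are placeholders, and the regrouping helper at the end is hypothetical, not part of this commit:

# Hedged sketch of the round trip this commit sets up for sglang: each prompt
# is repeated n times with per-request n=1, and meta_info["actual_n_rollouts"]
# lets the caller fold the flat outputs back afterwards.
n = 4
batch = model.prepare_inputs_for_inference(
    dataset=dataset,  # placeholder: a text_only or conversation lmflow dataset
    inference_engine="sglang",
    sampling_params={"temperature": 0.8, "n": n},
)
# batch now holds len(dataset) * n rows, each requesting a single sample.
batch = model.inference(batch, inference_engine="sglang")

# Hypothetical regrouping: fold n consecutive outputs back per original prompt.
outputs = batch.non_tensor_batch["outputs"]
k = batch.meta_info["actual_n_rollouts"]
grouped = [outputs[i : i + k] for i in range(0, len(outputs), k)]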