
Commit 0911ed4

[data] data protocol

1 parent bac6017, commit 0911ed4
12 files changed: 1183 additions & 59 deletions


.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -27,6 +27,9 @@ data/
 output_models
 adapter_model/
 
+# output data
+output_data/
+
 # Distribution / packaging
 .Python
 build/
```

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -290,7 +290,7 @@ bash ./scripts/run_chatbot.sh output_models/finetuned_gpt2
 >```bash
 >bash ./scripts/run_sglang_inference.sh
 >```
->
+> Note: If you encounter the error `ModuleNotFoundError: No module named 'common_ops'` when using SGLang, please try `apt-get update` and then `apt install numactl`.
 > </details>
 
 ### Deployment
````

examples/rm_inference.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -42,8 +42,8 @@ def main():
         dataset,
     )
 
-    if pipeline_args.save_results:
-        res.save(pipeline_args.results_path)
+    if pipeline_args.save_inference_results:
+        res.save(pipeline_args.inference_results_path)
 
 
 if __name__ == "__main__":
```

requirements.txt

Lines changed: 2 additions & 1 deletion
```diff
@@ -12,4 +12,5 @@ evaluate==0.4.0
 bitsandbytes>=0.40.0
 pydantic
 accelerate>=0.27.2
-einops>=0.6.1
+einops>=0.6.1
+tensordict
```

scripts/archive/run_rm_inference.sh

Lines changed: 2 additions & 2 deletions
```diff
@@ -67,6 +67,6 @@ accelerate launch --config_file configs/accelerator_multigpu_config.yaml \
     --overwrite_cache True \
     --conversation_template ${conversation_template} \
     --preprocessing_num_workers 16 \
-    --save_results True \
-    --results_path ${output_file_path} \
+    --save_inference_results True \
+    --inference_results_path ${output_file_path} \
     2>&1 | tee ${log_dir}/rm_inference.log
```

scripts/archive/run_vllm_inference.sh

Lines changed: 2 additions & 2 deletions
```diff
@@ -69,8 +69,8 @@ python examples/vllm_inference.py \
     --temperature 1.0 \
     --top_p 0.9 \
     --max_new_tokens 1024 \
-    --save_results True \
-    --results_path ${output_file_path} \
+    --save_inference_results True \
+    --inference_results_path ${output_file_path} \
     --enable_decode_inference_result False \
     --vllm_gpu_memory_utilization 0.95 \
     --vllm_tensor_parallel_size 2 \
```

scripts/run_sglang_inference.sh

Lines changed: 4 additions & 6 deletions
```diff
@@ -1,14 +1,12 @@
 python examples/sglang_inference.py \
-    --model_name_or_path Qwen/Qwen3-4B-Instruct-2507 \
-    --dataset_path data/alpaca/test_conversation \
-    --output_dir output_data/sglang_inference_results \
-    --output_file_name results.json \
+    --model_name_or_path Qwen/Qwen3-0.6B \
+    --dataset_path data/alpaca/prompt_only \
     --inference_engine sglang \
     --inference_gpu_memory_utilization 0.8 \
     --num_output_sequences 2 \
     --temperature 1.0 \
     --max_new_tokens 2048 \
     --top_p 0.95 \
     --random_seed 42 \
-    --save_results True \
-    --results_path output_data/sglang_inference_results/results.json
+    --save_inference_results True \
+    --inference_results_path output_data/sglang_inference_results/results.json
```

src/lmflow/args.py

Lines changed: 20 additions & 8 deletions
```diff
@@ -1053,8 +1053,11 @@ class InferencerArguments:
     )
 
     # Args for result saving
-    save_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."})
-    results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."})
+    save_results: Optional[bool] = field(default=None, metadata={"help": "Whether to save results."})
+    results_path: Optional[str] = field(default=None, metadata={"help": "The path of results."})
+
+    save_inference_results: Optional[bool] = field(default=False, metadata={"help": "Whether to save inference results."})
+    inference_results_path: Optional[str] = field(default=None, metadata={"help": "The path of inference results."})
 
     def __post_init__(self):
         if self.use_accelerator is not None:
@@ -1063,15 +1066,23 @@ def __post_init__(self):
                 "It will not take effect and will be removed in a future version, "
                 "since LMFlow now can automatically detect whether is in Accelerate or Deepspeed environment."
             )
-
+
         if self.save_results:
-            if self.results_path is None:
-                raise ValueError("Need to specify results_path when save_results is True.")
+            logger.warning("`save_results` is deprecated and will be removed in a future version. Please use `save_inference_results` instead.")
+            self.save_inference_results = self.save_results
+
+        if self.results_path:
+            logger.warning("`results_path` is deprecated and will be removed in a future version. Please use `inference_results_path` instead.")
+            self.inference_results_path = self.results_path
+
+        if self.save_inference_results:
+            if self.inference_results_path is None:
+                raise ValueError("Need to specify inference_results_path when save_inference_results is True.")
             else:
-                if not self.results_path.endswith(".json"):
-                    raise ValueError("The results_path must be a json file.")
+                if not self.inference_results_path.endswith(".json"):
+                    raise ValueError("The inference_results_path must be a json file.")
                 else:
-                    Path(self.results_path).parent.mkdir(parents=True, exist_ok=True)
+                    Path(self.inference_results_path).parent.mkdir(parents=True, exist_ok=True)
 
         if self.use_vllm is True:
             logger.warning(
@@ -1352,6 +1363,7 @@ class IterativeDPOAlignerArguments(IterativeAlignerArguments, DPOv2AlignerArguments):
     "evaluator": EvaluatorArguments,
     "inferencer": InferencerArguments,
     "vllm_inferencer": InferencerArguments,
+    "sglang_inferencer": InferencerArguments,
     "rm_inferencer": InferencerArguments,
     "raft_aligner": RaftAlignerArguments,
     "dpo_aligner": DPOAlignerArguments,
```

src/lmflow/models/hf_decoder_model.py

Lines changed: 30 additions & 19 deletions
```diff
@@ -18,6 +18,7 @@
 import os
 from typing import Literal, Optional, Union
 
+import numpy as np
 import torch
 from peft import PeftModel
 from transformers import (
@@ -40,6 +41,7 @@
 from lmflow.utils.deprecated import deprecated_args
 from lmflow.utils.envs import is_accelerate_env
 from lmflow.utils.versioning import is_flash_attn_available, is_ray_available, is_vllm_available
+from lmflow.utils.protocol import DataProto
 
 logger = logging.getLogger(__name__)
 
@@ -286,7 +288,7 @@ def decode(self, input, **kwargs) -> Union[str, list[str]]:
     )
     def inference(
         self,
-        inputs: Union[str, list[str], torch.Tensor],
+        inputs: Union[str, list[str], torch.Tensor, DataProto],
         sampling_params: Optional[Union[dict, "SamplingParams"]] = None,
         return_logprob: bool = False,
         release_gpu: bool = False,
@@ -296,16 +298,17 @@ def inference(
         enable_deterministic_inference: bool = False,
         attention_backend: Optional[str] = None,
         **kwargs,
-    ):
+    ) -> Union[list[VLLMInferenceResultWithInput] | DataProto]:
         """
         Perform generation process of the model.
 
         Parameters
         ------------
-        inputs : Union[str, list[str], torch.Tensor]
+        inputs : Union[str, list[str], torch.Tensor, DataProto]
             The sequence used as a prompt for the generation or as model inputs to the model.
-            When the inference engine is "vllm" or "sglang", this should be a string or a list of strings.
+            When the inference engine is "vllm", this should be a string or a list of strings.
             When the inference engine is "huggingface", this should be a tensor.
+            When the inference engine is "sglang", this should be a DataProto.
         sampling_params : Optional[Union[dict, "SamplingParams"]], optional
             The sampling parameters to use, by default None.
         return_logprob : bool, optional
@@ -345,7 +348,6 @@ def inference(
         elif inference_engine == "sglang":
             res = self.__sglang_inference(
                 inputs=inputs,
-                sampling_params=sampling_params,
                 return_logprob=return_logprob,
             )
         else:
@@ -439,21 +441,18 @@ def __vllm_inference(
 
     def __sglang_inference(
         self,
-        inputs: list[str],
-        sampling_params: Optional[dict] = None,
+        inputs: DataProto,
         return_logprob: bool = False,
-    ):
+    ) -> DataProto:
         """Perform SGLang inference process of the model."""
         sglang_outputs = self.backend_model_for_inference.generate(
-            prompt=inputs,
-            sampling_params=sampling_params,
+            prompt=inputs.non_tensor_batch["inputs"].tolist(),  # use tensor instead of str later
+            sampling_params=inputs.meta_info["sampling_params"],
             return_logprob=return_logprob,
         )
-        # TODO: unified lmflow sample format
-        for idx, output in enumerate(sglang_outputs):
-            output["input"] = inputs[idx]
-            output["output"] = output.pop("text")
-        return sglang_outputs
+        inputs.non_tensor_batch["outputs"] = [output["text"] for output in sglang_outputs]
+        # TODO: padding for batching the output ids; generation details
+        return inputs
 
     @deprecated_args(
         use_vllm={
@@ -471,7 +470,8 @@ def prepare_inputs_for_inference(
         apply_chat_template: bool = True,
         inference_engine: Literal["huggingface", "vllm", "sglang"] = "huggingface",
         enable_distributed_inference: bool = False,
-    ) -> Union[list[str], "ray.data.Dataset"]:
+        sampling_params: Optional[dict] = None,
+    ) -> Union[list[str], "ray.data.Dataset", DataProto]:
         if dataset.get_type() == "text_only":
             if apply_chat_template:
                 dataset = dataset.map(
@@ -551,9 +551,20 @@ def preprocess_conversation(sample):
                 inference_inputs
             )  # -> dict[str, np.ndarray], {"item": array(['...', '...', '...'])}
 
-        if inference_engine == "sglang" and self.tokenizer.bos_token:
-            # consistent with the sglang bench_serving.py demo
-            inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]
+        if inference_engine == "sglang":
+            if self.tokenizer.bos_token:
+                # consistent with the sglang bench_serving.py demo
+                inference_inputs = [sentence.replace(self.tokenizer.bos_token, "") for sentence in inference_inputs]
+
+            # currently only test dataproto on sglang inference
+            inference_inputs = np.array(inference_inputs)
+            inference_inputs = DataProto.from_single_dict(
+                data={"inputs": inference_inputs},
+                meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]}
+            )
+
+            # handling n>1 since we don't want one-to-many mapping. Later this will be applied to all inference engines.
+            inference_inputs = inference_inputs.repeat(sampling_params["n"])
 
         return inference_inputs
```
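
Taken together, the SGLang path now moves data through the new `DataProto` container from `lmflow.utils.protocol` (added in this commit): prompts travel in `non_tensor_batch`, sampling parameters in `meta_info`, and `n > 1` is handled by repeating rows rather than mapping one input to many outputs. A rough sketch of that flow, mirroring the diff above (the precise `DataProto` semantics, e.g. that `repeat` duplicates rows in order, are assumptions):

```python
import numpy as np

from lmflow.utils.protocol import DataProto

prompts = np.array(["What is LMFlow?", "Write a haiku about GPUs."])
sampling_params = {"temperature": 1.0, "top_p": 0.95, "max_new_tokens": 2048, "n": 2}

# Same construction as prepare_inputs_for_inference: each row requests a single
# completion, and the originally requested n is kept in meta_info.
batch = DataProto.from_single_dict(
    data={"inputs": prompts},
    meta_info={"sampling_params": {**sampling_params, "n": 1}, "actual_n_rollouts": sampling_params["n"]},
)
batch = batch.repeat(sampling_params["n"])  # 2 prompts -> 4 rows, one completion each

# __sglang_inference reads the prompts and sampling parameters back out of the proto
# and writes the generated texts into non_tensor_batch["outputs"].
print(batch.non_tensor_batch["inputs"].tolist())
print(batch.meta_info["sampling_params"])
```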

src/lmflow/pipeline/sglang_inferencer.py

Lines changed: 15 additions & 18 deletions
```diff
@@ -15,6 +15,7 @@
 from lmflow.models.hf_decoder_model import HFDecoderModel
 from lmflow.pipeline.base_pipeline import BasePipeline
 from lmflow.utils.versioning import is_sglang_available
+from lmflow.utils.protocol import DataProto
 
 logger = logging.getLogger(__name__)
 
@@ -64,7 +65,7 @@ def inference(
         dataset: Dataset,
         release_gpu: bool = False,
         inference_args: Optional[InferencerArguments] = None,
-    ):
+    ) -> DataProto:
         if inference_args:
             logger.warning("Overriding the default inference arguments with the provided arguments in .inference()")
             sampling_params = self._parse_args_to_sampling_params(inference_args)
@@ -76,13 +77,11 @@ def inference(
             dataset=dataset,
             apply_chat_template=self.inferencer_args.apply_chat_template,
             inference_engine="sglang",
+            sampling_params=sampling_params,
         )
-        # handling n>1 since we don't want one-to-many mapping
-        model_input = [sample for sample in model_input for _ in range(sampling_params["n"])]
 
         outputs = model.inference(
             inputs=model_input,
-            sampling_params=sampling_params.copy().update({"n": 1}),
             return_logprob=self.inferencer_args.return_logprob,
             release_gpu=release_gpu,
             inference_engine="sglang",
@@ -92,26 +91,24 @@ def inference(
             attention_backend=self.inferencer_args.attention_backend,
         )
 
-        if self.inferencer_args.save_results:
-            self.save_inference_results(outputs, self.inferencer_args.results_path)
+        if self.inferencer_args.save_inference_results:
+            self.save_inference_results(outputs, self.inferencer_args.inference_results_path)
 
         return outputs
 
     def save_inference_results(
         self,
-        outputs: Union[list[list[str]], list[list[list[int]]]],
-        save_file_path: str,
+        outputs: DataProto,
+        inference_results_path: str,
     ):
-        with open(save_file_path, "w", encoding="utf-8") as f:
-            json.dump(outputs, f, ensure_ascii=False, indent=4)
-
-        logger.info(f"Inference results are saved to {save_file_path}.")
+        if not inference_results_path.endswith(".pkl"):
+            logger.warning(f"The inference results path must be a pickle file. Changing the path to {inference_results_path}.pkl")
+            inference_results_path = inference_results_path + ".pkl"
+        outputs.save_to_disk(inference_results_path)
+        logger.info(f"Inference results are saved to {inference_results_path}.")
 
     def load_inference_results(
         self,
-        results_path: str,
-    ) -> Union[list[list[str]], list[list[list[int]]]]:
-        with open(results_path) as f:
-            results = json.load(f)
-
-        return results
+        inference_results_path: str,
+    ) -> DataProto:
+        return DataProto.load_from_disk(inference_results_path)
```
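
Since outputs are now a `DataProto`, results are persisted with `save_to_disk`/`load_from_disk` (pickle) rather than JSON. A hedged usage sketch follows; constructing the pipeline through `AutoPipeline` mirrors LMFlow's example scripts and is an assumption rather than part of this diff. Note also that `args.py` still validates that `inference_results_path` ends in `.json`, so in practice `save_inference_results` will warn and write `<path>.pkl` (e.g. `results.json.pkl`):

```python
from lmflow.pipeline.auto_pipeline import AutoPipeline

# model_args, data_args, pipeline_args, model and dataset are assumed to be set up
# as in examples/sglang_inference.py, with --save_inference_results True and an
# --inference_results_path pointing into output_data/.
inferencer = AutoPipeline.get_pipeline(
    pipeline_name="sglang_inferencer",  # registered in args.py by this commit
    model_args=model_args,
    data_args=data_args,
    pipeline_args=pipeline_args,
)

outputs = inferencer.inference(model=model, dataset=dataset)  # -> DataProto

# Reload later without rerunning generation; remember the ".pkl" suffix the saver appends.
reloaded = inferencer.load_inference_results(pipeline_args.inference_results_path + ".pkl")
print(reloaded.non_tensor_batch["outputs"][:3])
```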
