Issue status: Open
Description
Hi,
Thank you for releasing the code, model, and benchmark; this is a very valuable contribution.
I am trying to reproduce the EHR-Bench results reported in Figure 7 in the paper using:
- Model: https://huggingface.co/BlueZeros/EHR-R1-1.7B
- Dataset: https://huggingface.co/datasets/BlueZeros/EHR-Bench
I followed the EHR input format described in the README and used the “decision_making” split provided by EHR-Bench. However, my reproduced results on EHR-Bench are consistently lower than those reported in the paper. Below are my task-level F1 scores:
| Task | F1 |
|---|---|
| admissions | 0.0607 |
| chartevents | 0.0819 |
| datetimeevents | 0.1602 |
| diagnoses_ccs | 0.2315 |
| diagnoses_icd | 0.0858 |
| diagnosis | 0.1231 |
| diagnosis_ccs | 0.2777 |
| emar | 0.1761 |
| ingredientevents | 0.5862 |
| inputevents | 0.1774 |
| labevents | 0.1668 |
| medrecon | 0.1383 |
| medrecon_atc | 0.2092 |
| microbiologyevents | 0.1295 |
| omr | 0.3863 |
| outputevents | 0.1629 |
| poe | 0.2101 |
| prescriptions | 0.2241 |
| prescriptions_atc | 0.4041 |
| procedureevents | 0.1987 |
| procedures_ccs | 0.1485 |
| procedures_icd | 0.0955 |
| pyxis | 0.0678 |
| radiology | 0.0916 |
| services | 0.2440 |
| transfers | 0.3054 |
I have attached my evaluation script for reference.
I would like to ask:
- Are there any additional preprocessing steps, prompt templates, or evaluation details that are not fully specified in the repository?
- Is there a recommended inference configuration (e.g., temperature) used for the reported EHR-Bench results?
- Are the reported results obtained using this exact HF model checkpoint, or were there additional steps involved?
- If possible, can you provide a script for reproducing it?
Any clarification would be greatly appreciated.
Thank you very much for your time!
Best regards,
My code:
#!/usr/bin/env python
# coding: utf-8
# In[1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams
import torch
import os, json
import os
import re
import sys
import time
import json
import argparse
import random
import copy
import pandas as pd
from tqdm import tqdm
from collections import Counter
from joblib import Parallel, delayed
from test_v1 import score_func
from transformers import AutoConfig
import os
import torch
import math
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from models.base_model import Base_Model
# Tensor-parallel degree for vLLM: one shard per visible GPU.
_cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if _cvd is None:
    # Variable unset: every local GPU is visible to the process.
    tp = torch.cuda.device_count()
else:
    # Count only non-empty entries so CUDA_VISIBLE_DEVICES="" or a trailing
    # comma ("0,1,") does not inflate the parallel degree.
    tp = len([dev for dev in _cvd.split(",") if dev.strip()])
# vLLM requires tensor_parallel_size >= 1 (also covers CPU-only hosts).
tp = max(tp, 1)
class VLLM_Model(Base_Model):
    """Thin wrapper around a pre-built vLLM ``LLM`` engine.

    Prompt preparation (``prepare_inputs``), the tokenizer and the
    ``model_path`` / ``peft_path`` attributes come from ``Base_Model``,
    which is not visible in this file — their exact contract cannot be
    confirmed from here.
    """

    def __init__(self, model, gpu_memory_utilization=0.7, max_seq_len=32000):
        # NOTE(review): `model_name` is not a parameter — this resolves to a
        # module-level global assigned in the __main__ section below. It only
        # works because the global exists before this constructor runs;
        # passing the name in explicitly would be safer.
        # NOTE(review): gpu_memory_utilization is accepted but never used here.
        super().__init__(model_name)
        model_config = AutoConfig.from_pretrained(self.model_path, use_fast=False, trust_remote_code=True)
        # Clamp the usable context to the model's declared positional limit,
        # falling back to max_seq_len when the config lacks the field.
        self.max_seq_len = min(getattr(model_config, "max_position_embeddings", max_seq_len), max_seq_len)
        self.model = model

    def model_forward(self, inputs, sampling_params):
        """Tokenize, left-truncate to fit the context window, then generate."""
        input_ids = self.tokenizer(inputs)["input_ids"]
        # Keep only the most recent tokens, reserving max_tokens of headroom
        # for generation — left truncation drops the oldest EHR events first.
        input_ids = [input_id[-self.max_seq_len+sampling_params.max_tokens:] for input_id in input_ids]
        if self.peft_path:
            # peft_path comes from Base_Model; attach the LoRA adapter when set.
            outputs = self.model.generate(prompt_token_ids=input_ids, sampling_params=sampling_params, lora_request=LoRARequest("lora", 1, self.peft_path), use_tqdm=False)
        else:
            outputs = self.model.generate(prompt_token_ids=input_ids, sampling_params=sampling_params, use_tqdm=False)
        return outputs

    def __call__(self, inputs, infer_args, logit_bias_words=None, enable_thinking=False):
        """Generate and return, per input, a list of candidate dicts with keys
        ``trajectory`` (stripped output text) and ``logits`` (sequence
        probability, the exp of summed token logprobs)."""
        inputs = self.prepare_inputs(inputs, enable_thinking=enable_thinking)
        sampling_params = SamplingParams(logit_bias=None if not logit_bias_words else self.get_logit_bias(logit_bias_words), logprobs=1, **infer_args)
        outputs = self.model_forward(inputs, sampling_params)
        outputs = [
            [
                {
                    "trajectory": output.text.strip(),
                    # With a logit bias (single-token yes/no) use the cumulative
                    # logprob directly; otherwise re-sum the top logprob of every
                    # token except the last — presumably to exclude EOS; confirm.
                    "logits": math.exp(output.cumulative_logprob) if logit_bias_words else math.exp(sum([list(logprob.values())[0].logprob for logprob in output.logprobs[:-1]]))
                }
                for output in outputs[i].outputs
            ]
            for i in range(len(inputs))
        ]
        return outputs
# In[3]:
from argparse import Namespace
# Frozen copy of the command-line configuration used by the original
# evaluation script (no actual CLI parsing happens in this notebook export).
_run_config = {
    # data args
    "resume": False,
    # dataset args
    "root_dir": "./datas",
    "lazzy_mode": False,
    "chunk_num": 1,
    "chunk_idx": 0,
    "cdm_candidate": False,
    # model args
    "use_vllm": False,
    "url": None,
    "gpu_memory_utilization": 0.8,
    "max_seq_len": 32000,
    # inference args
    "prompt": False,
    "batch": 1,
    "sample_num": 1,
    "temperature": 0.0,
    "top_p": 0.95,
    "top_k": 20,
    "max_new_tokens": 1,
    "think_prompt": True,
    "direct_answer": True,
    "oracle_reasoning": False,
}
args = Namespace(**_run_config)

# Convenience aliases used by the inference loop below.
batch = args.batch
sample_num = args.sample_num
# In[4]:
# Instruction suffix for open-ended decision-making tasks: forces plain item
# names (no ICD/ATC codes), one item per line.
DECISION_MAKING_PROMPT = """Note that you should directly output the answer without any other information. If there are several items in the prediction, please separate them by `\\n`. For all predicted items, please use the item name instead of the item code. Do not output the code like ICD10 or ICD9."""
# Instruction suffix for closed-set risk-prediction tasks (single choice).
RISK_PREDICTION_PROMPT = """Note that you should directly output the answer without any other information. You can only choose one answer from the Candidate List."""
# Appended when the example ships an explicit candidate list.
CANDIDATE_PROMPT = """You should choose the item from the candidate list below. Candidate List: {candidates}."""
# In[5]:
if __name__ == "__main__":
    # NOTE(review): only the engine construction below is guarded; everything
    # after this block (dataset loading, the inference loop) still runs at
    # import time and references `model_wrapped`, so importing this module
    # from another process would crash. Acceptable only for notebook use.
    import multiprocessing as mp
    mp.set_start_method("spawn", force=True)  # spawn before any CUDA init, as vLLM workers need
    # Local path / HF repo id; also read as a global by VLLM_Model.__init__.
    model_name = "EHR-R1-1.7B"
    model = LLM(
        model=model_name,
        tensor_parallel_size=tp,
        trust_remote_code=True,
        max_model_len=32000,
        max_seq_len_to_capture=32000,
        # NOTE(review): differs from args.gpu_memory_utilization (0.8) —
        # confirm which value is intended.
        gpu_memory_utilization=0.7
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_wrapped = VLLM_Model(model)
# In[6]:
from datasets import load_dataset

# Stream the decision-making split of EHR-Bench from a local JSONL dump.
# streaming=True yields an iterable dataset, so examples are consumed lazily
# by the batching loop below (no random access, no len()).
ds = load_dataset(
    "json",
    data_files={"decision_making": "EHR-Bench/ehr_bench_decision_making.jsonl"},
    streaming=True,
)
def format_mimic_data(data, prompt=True, think_prompt=True, direct_answer=True, oracle_reasoning=False):
    """Convert one raw EHR-Bench example into a chat-style prompt record.

    Args:
        data: raw example with ``input``, ``instruction``, ``candidates`` and
            ``task_info`` fields.
        prompt: append the task-type instruction suffix (and candidate list).
        think_prompt: append a Qwen-style ``/think`` / ``/no_think`` switch.
        direct_answer: with ``think_prompt``, selects ``/no_think``.
        oracle_reasoning: when the example carries a stored reasoning trace,
            seed the assistant turn with it (teacher-forcing style).

    Returns:
        dict with ``prompt`` (list of chat messages) and a deep-copied
        ``task_info`` that also records the candidate list.
    """
    import ast

    ehr_input = data["input"]  # renamed: original shadowed builtin `input`
    instruction = data["instruction"]
    candidates = data["candidates"]
    if prompt:
        suffix = DECISION_MAKING_PROMPT if data["task_info"]["task_type"] == "decision_making" else RISK_PREDICTION_PROMPT
        parts = [ehr_input, instruction]
        if candidates:
            parts.append(CANDIDATE_PROMPT.format(candidates=candidates))
        parts.append(suffix)
        input_prompt = "\n".join(parts)
    else:
        input_prompt = ehr_input + "\n" + instruction
    if think_prompt:
        # Qwen3-style soft switch between reasoning and direct answering.
        input_prompt += "\n/no_think" if direct_answer else "\n/think"

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": input_prompt},
    ]
    if oracle_reasoning and data["task_info"].get("reasoning", None):
        try:
            # The stored trace is the repr of an OpenAI-style response dict.
            # ast.literal_eval replaces the original eval(): it accepts the
            # same literal inputs but cannot execute arbitrary code from data.
            parsed = ast.literal_eval(data["task_info"]["reasoning"])
            reasoning = parsed["choices"][0]["message"]["content"]
            # Drop everything from "Final Results" onward so the model must
            # produce the answer itself.
            reasoning_wo_answer = reasoning.rsplit("Final Results", 1)[0]
            messages.append({"role": "assistant", "content": f"<think>{reasoning_wo_answer}"})
            data["task_info"]["reasoning"] = reasoning_wo_answer
        except Exception:
            # Malformed/undecodable trace: fall back to the plain user prompt.
            pass

    clean_data = {"prompt": messages}
    # Deep-copy AFTER the possible truncation above so the copy carries the
    # shortened reasoning, matching the original ordering.
    clean_data["task_info"] = copy.deepcopy(data["task_info"])
    clean_data["task_info"]["candidates"] = candidates
    return clean_data
def get_data_chunk(sample_infos, chunk_num, chunk_idx):
    """Return the chunk_idx-th of chunk_num near-equal slices of sample_infos.

    Uses ceiling division so chunks are balanced; the original sizing of
    ``len // chunk_num + 1`` over-allocated early chunks and could leave
    trailing chunks empty (e.g. 3 items over 3 chunks gave sizes 2, 1, 0).
    """
    chunk_size = -(-len(sample_infos) // chunk_num)  # ceil division
    return sample_infos[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size]
# In[8]:
task="decision_making"
def direct_answer_infer(args, model, examples):
    """Single-pass answer generation (no chain-of-thought), batch size 1."""
    prompts = [ex["prompt"] for ex in examples]
    assert len(prompts) == 1
    task_type = examples[0]["task_info"]["task_type"]
    # Risk prediction is scored as a constrained yes/no choice; every other
    # task gets a short free-form completion.
    if task_type == "risk_prediction":
        bias_words, token_budget = ["yes", "no"], 1
    else:
        bias_words, token_budget = None, 128
    sampling = {
        "n": args.sample_num,
        "max_tokens": token_budget,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
    }
    results = model(
        prompts,
        infer_args=sampling,
        logit_bias_words=bias_words,
        enable_thinking=False,
    )
    # Shallow-copy each per-input candidate list before returning.
    return [list(candidates) for candidates in results]
def reasoning_infer(args, model, examples):
    """Generate with chain-of-thought enabled, then split each sample into
    ``reasoning`` and final ``trajectory`` parts.

    Risk-prediction tasks are still answered directly (yes/no, 1 token,
    thinking disabled); all other tasks think up to ``args.max_new_tokens``.
    """
    inputs = [example["prompt"] for example in examples]
    if examples[0]["task_info"]["task_type"] == "risk_prediction":
        logit_bias_words = ["yes", "no"]
        max_new_tokens = 1
        enable_thinking = False
    else:
        logit_bias_words = None
        max_new_tokens = args.max_new_tokens
        enable_thinking = True
    infer_args = {
        "n": args.sample_num,
        "max_tokens": max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
    }
    outputs = model(
        inputs,
        infer_args=infer_args,
        logit_bias_words=logit_bias_words,
        enable_thinking=enable_thinking,
    )

    def split_reasoning_trajectory(model_name, output):
        """Split one sample's raw text into reasoning + answer, in place.

        Best-effort: on any parse failure (e.g. a truncated generation that
        lacks the closing think tag) the sample is returned unchanged.
        """
        try:
            if "reasoning" in output:
                return output  # already split (e.g. a cached sample)
            elif "qwen" in model_name or "m2" in model_name:
                # Qwen-style: "<think>...</think>answer". Exactly one closing
                # tag expected; more than one raises and falls through.
                reasoning, trajectory = output["trajectory"].split("</think>")
                reasoning += "</think>"
                output["trajectory"] = trajectory.strip()
                output["reasoning"] = reasoning
            elif "gpt_oss" in model_name or "gpt-oss" in model_name:
                # Harmony format: analysis channel carries the reasoning,
                # final channel carries the answer.
                trajectory = re.findall(r"<\|channel\|>final<\|message\|>(.*?)<\|return\|>", output["trajectory"], re.DOTALL)[0]
                reasoning = re.findall(r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>", output["trajectory"], re.DOTALL)[0]
                output["trajectory"] = trajectory.strip()
                output["reasoning"] = reasoning
            else:
                raise NotImplementedError
            return output
        except Exception:
            return output

    # Fix: the run configuration defined above has no `model_name_or_path`
    # attribute, so the original `args.model_name_or_path` raised
    # AttributeError on this path. Fall back to the global model name.
    model_id = getattr(args, "model_name_or_path", None) or globals().get("model_name", "")
    outputs = [[split_reasoning_trajectory(model_id, sample) for sample in output] for output in outputs]
    return outputs
# In[11]:
def iter_batches(iterable, batch_size):
    """Yield successive lists of at most batch_size items from iterable.

    The final chunk may be shorter than batch_size; an empty iterable
    yields nothing.
    """
    source = iter(iterable)
    while True:
        chunk = []
        for _ in range(batch_size):
            try:
                chunk.append(next(source))
            except StopIteration:
                if chunk:
                    yield chunk
                return
        yield chunk
# In[12]:
def get_uid(example):
    """Stable identity key for an example: its task_info serialized with
    sorted keys, so equal metadata always maps to the same string."""
    canonical = json.dumps(example["task_info"], ensure_ascii=False, sort_keys=True)
    return canonical
# Resume support: the cache file holds one JSON row per completed example,
# keyed by the uid produced by get_uid().
cache_path = f"./cache/ehrbench_{task}.jsonl"
done_uids = set()
if os.path.exists(cache_path):
    with open(cache_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                # Tolerate a torn final line from an interrupted run.
                continue
            uid = row.get("uid")
            # Fix: the original added None for rows lacking a uid, which can
            # never match a real uid and silently polluted the set.
            if uid is not None:
                done_uids.add(uid)
# In[13]:
print(f"Resume: found {len(done_uids)} completed examples in cache.")
# Fix: append mode raises FileNotFoundError when ./cache does not exist yet
# (i.e. on the very first run), so create the directory first.
os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
# Long-lived append handle; rows are flushed per batch and the handle is
# closed after the inference loop below.
cache_fp = open(cache_path, "a", encoding="utf-8")
# In[ ]:
# ---- main inference loop ---------------------------------------------------
import gc
import torch
torch.cuda.empty_cache()
# NOTE(review): format_mimic_data is called with its DEFAULTS here, not with
# args.prompt / args.think_prompt / args.direct_answer — so the instruction
# suffix and the /no_think switch come from the function signature, while the
# Namespace above says prompt=False. Confirm which configuration the reported
# numbers used; this mismatch could explain a reproduction gap.
for batch_id, raw_batch in enumerate(
    tqdm(iter_batches(ds[task], batch), desc=f"Inference with k={sample_num}")
):
    # format raw rows into chat-style prompt records
    examples_all = [format_mimic_data(ex) for ex in raw_batch]
    # filter out examples already present in the resume cache
    examples = []
    for ex in examples_all:
        uid = get_uid(ex)
        if uid not in done_uids:
            ex["_uid"] = uid
            examples.append(ex)
    if len(examples) == 0:
        continue
    start_time = time.time()
    # args.direct_answer is True in the config above, so the reasoning branch
    # is never taken in this run.
    if args.direct_answer:
        trajectories = direct_answer_infer(args, model_wrapped, examples)
    else:
        trajectories = reasoning_infer(args, model_wrapped, examples)
    end_time = time.time()
    # One log row per example: uid for resume, raw model outputs, timing meta.
    example_logs = [
        {
            "uid": ex["_uid"],
            "task_info": ex["task_info"],
            "outputs": trajectories[i],
            "meta": {
                "task": task,
                "batch_id": batch_id,
                # wall-clock for the whole batch, not per example
                "inference_sec": end_time - start_time,
            }
        }
        for i, ex in enumerate(examples)
    ]
    # score_func (from test_v1, not visible here) presumably attaches metric
    # fields to each row — confirm its contract against the repo.
    example_logs = score_func(example_logs)
    for row in example_logs:
        cache_fp.write(json.dumps(row, ensure_ascii=False) + "\n")
        done_uids.add(row["uid"])  # mark complete within this session too
    cache_fp.flush()  # persist progress so an interrupted run can resume
    # Drop per-batch references before the next batch to keep peak memory low.
    torch.cuda.empty_cache()
    del examples, trajectories, example_logs
    gc.collect()
cache_fp.close()
Metadata: no assignees, no labels.