Skip to content

Difficulty reproducing EHR-Bench results with EHR-R1-1.7B #5

@yundaqwe

Description

@yundaqwe

Hi,

Thank you for releasing the code, model, and benchmark; this is a very valuable contribution.

I am trying to reproduce the EHR-Bench results reported in Figure 7 of the paper using the released EHR-R1-1.7B checkpoint from Hugging Face.

I followed the EHR input format described in the README and used the “decision_making” split provided by EHR-Bench. However, my reproduced results on EHR-Bench are consistently lower than those reported in the paper. Below are my task-level F1 scores:

| Task | F1 |
|---|---|
| admissions | 0.0607 |
| chartevents | 0.0819 |
| datetimeevents | 0.1602 |
| diagnoses_ccs | 0.2315 |
| diagnoses_icd | 0.0858 |
| diagnosis | 0.1231 |
| diagnosis_ccs | 0.2777 |
| emar | 0.1761 |
| ingredientevents | 0.5862 |
| inputevents | 0.1774 |
| labevents | 0.1668 |
| medrecon | 0.1383 |
| medrecon_atc | 0.2092 |
| microbiologyevents | 0.1295 |
| omr | 0.3863 |
| outputevents | 0.1629 |
| poe | 0.2101 |
| prescriptions | 0.2241 |
| prescriptions_atc | 0.4041 |
| procedureevents | 0.1987 |
| procedures_ccs | 0.1485 |
| procedures_icd | 0.0955 |
| pyxis | 0.0678 |
| radiology | 0.0916 |
| services | 0.2440 |
| transfers | 0.3054 |

I have attached my evaluation script for reference.

I would like to ask:

  1. Are there any additional preprocessing steps, prompt templates, or evaluation details that are not fully specified in the repository?
  2. Is there a recommended inference configuration (e.g., temperature) used for the reported EHR-Bench results?
  3. Are the reported results obtained using this exact HF model checkpoint, or were there additional steps involved?
  4. If possible, could you provide a script for reproducing these results?

Any clarification would be greatly appreciated.
Thank you very much for your time!

Best regards,

My code:

#!/usr/bin/env python
# coding: utf-8

# In[1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams
import torch
import os, json
import os
import re
import sys
import time
import json
import argparse
import random
import copy
import pandas as pd
from tqdm import tqdm
from collections import Counter
from joblib import Parallel, delayed
from test_v1 import score_func
from transformers import AutoConfig
import os
import torch
import math
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from models.base_model import Base_Model

def detect_tensor_parallel_size(environ=None):
    """Return the tensor-parallel degree to hand to vLLM.

    Counts the non-empty device ids listed in ``CUDA_VISIBLE_DEVICES``.
    When the variable is unset or lists no ids, falls back to
    ``torch.cuda.device_count()``.  The result is clamped to at least 1 so
    vLLM never receives ``tensor_parallel_size=0``.

    Args:
        environ: Mapping to read ``CUDA_VISIBLE_DEVICES`` from; defaults to
            ``os.environ``.

    Returns:
        A positive int.
    """
    if environ is None:
        environ = os.environ
    visible = environ.get("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        # Ignore blank entries: "0,1," or " 0 , 1 " must count as two devices,
        # not three (a plain len(split(",")) over-counts).
        device_ids = [dev for dev in visible.split(",") if dev.strip()]
        if device_ids:
            return len(device_ids)
    return max(1, torch.cuda.device_count())


tp = detect_tensor_parallel_size()


class VLLM_Model(Base_Model):
    """Adapter that runs chat prompts through an already-built vLLM ``LLM``.

    ``prepare_inputs``, ``get_logit_bias`` and the ``model_path`` /
    ``peft_path`` / ``tokenizer`` attributes are used here but not defined
    here; they are expected to come from ``Base_Model`` (not visible in this
    file — confirm there).
    """

    def __init__(self, model, gpu_memory_utilization=0.7, max_seq_len=32000, model_name_or_path=None):
        """Wrap a constructed vLLM engine.

        Args:
            model: A ``vllm.LLM`` instance used for generation.
            gpu_memory_utilization: Not used; kept for backward compatibility
                (GPU memory is configured when ``model`` is constructed).
            max_seq_len: Upper bound on prompt + generated tokens.  It is further
                capped by the checkpoint's ``max_position_embeddings``.
            model_name_or_path: Name or path passed to ``Base_Model``.  When
                omitted, falls back to the script-level ``model_name`` global,
                which earlier versions of this class always (implicitly) used.
        """
        if model_name_or_path is None:
            # Backward compatibility: previously ``super().__init__(model_name)``
            # read the module global regardless of what was passed in.
            model_name_or_path = model_name
        super().__init__(model_name_or_path)
        model_config = AutoConfig.from_pretrained(self.model_path, use_fast=False, trust_remote_code=True)
        self.max_seq_len = min(getattr(model_config, "max_position_embeddings", max_seq_len), max_seq_len)

        self.model = model

    def model_forward(self, inputs, sampling_params):
        """Tokenize ``inputs``, left-truncate each prompt, and generate.

        Each prompt keeps only its last ``max_seq_len - max_tokens`` tokens, so
        prompt plus completion fit in the context window.

        NOTE(review): left truncation also drops the beginning of the rendered
        chat template for very long EHR inputs.  Confirm this matches the
        truncation used for the reported results.
        """
        input_ids = self.tokenizer(inputs)["input_ids"]
        # Always keep at least one prompt token.  The previous slice
        # ``[-max_seq_len + max_tokens:]`` gets a non-negative start when
        # max_tokens >= max_seq_len.  It then either kept the whole prompt or
        # dropped a fixed-length prefix instead of truncating to the budget.
        prompt_budget = max(self.max_seq_len - sampling_params.max_tokens, 1)
        input_ids = [input_id[-prompt_budget:] for input_id in input_ids]

        generate_kwargs = {}
        if self.peft_path:
            generate_kwargs["lora_request"] = LoRARequest("lora", 1, self.peft_path)
        return self.model.generate(
            prompt_token_ids=input_ids,
            sampling_params=sampling_params,
            use_tqdm=False,
            **generate_kwargs,
        )

    def __call__(self, inputs, infer_args, logit_bias_words=None, enable_thinking=False):
        """Generate completions for a batch of chat prompts.

        Args:
            inputs: Prompts as lists of role/content messages; rendered by
                ``prepare_inputs``.
            infer_args: Keyword arguments forwarded to ``SamplingParams``
                (``n``, ``max_tokens``, ``temperature``, ...).
            logit_bias_words: Optional words turned into a logit bias via
                ``get_logit_bias``; callers use it for yes/no tasks.
            enable_thinking: Forwarded to ``prepare_inputs``.

        Returns:
            One list per prompt, holding ``n`` dicts.  Each dict has the stripped
            ``"trajectory"`` text and a ``"logits"`` score, which is ``exp`` of a
            summed log-probability (see ``_score``).
        """
        inputs = self.prepare_inputs(inputs, enable_thinking=enable_thinking)
        logit_bias = self.get_logit_bias(logit_bias_words) if logit_bias_words else None
        sampling_params = SamplingParams(logit_bias=logit_bias, logprobs=1, **infer_args)
        request_outputs = self.model_forward(inputs, sampling_params)

        def _score(completion):
            # With a logit bias, use vLLM's cumulative log-prob directly.
            if logit_bias_words:
                return math.exp(completion.cumulative_logprob)
            # Otherwise sum the first entry of each position's logprob dict and
            # skip the final position (presumably EOS).
            # NOTE(review): confirm the first dict entry is the sampled token.
            return math.exp(sum(list(step.values())[0].logprob for step in completion.logprobs[:-1]))

        # vLLM returns one RequestOutput per prompt, in prompt order.
        return [
            [
                {"trajectory": completion.text.strip(), "logits": _score(completion)}
                for completion in request_output.outputs
            ]
            for request_output in request_outputs
        ]


# In[3]:


from argparse import Namespace

# Run configuration: a notebook-style stand-in for parsed CLI arguments.
# NOTE(review): in the code visible here, resume, root_dir, lazzy_mode,
# chunk_num, chunk_idx, cdm_candidate, use_vllm, url, gpu_memory_utilization,
# max_seq_len and prompt are never read.  For example, the main loop always
# resumes from its cache, and the LLM is built with hard-coded settings.
args = Namespace(
    # data args
    resume=False,

    # dataset args
    root_dir="./datas",
    lazzy_mode=False,
    chunk_num=1,
    chunk_idx=0,
    cdm_candidate=False,

    # model args
    # Read by reasoning_infer to pick how reasoning is split from the answer.
    # It was previously missing, so that path raised AttributeError.
    model_name_or_path="EHR-R1-1.7B",
    use_vllm=False,
    url=None,
    gpu_memory_utilization=0.8,
    max_seq_len=32000,

    # inference args
    prompt=False,
    batch=1,
    sample_num=1,
    temperature=0.0,
    top_p=0.95,
    top_k=20,
    # Generation cap for the reasoning (/think) path only; the direct-answer
    # path uses its own limit.  The old value of 1 cut every reasoning trace
    # off after a single token.  TODO confirm the value used for the paper.
    max_new_tokens=8192,
    think_prompt=True,
    direct_answer=True,
    oracle_reasoning=False,
)

batch = args.batch
sample_num = args.sample_num


# In[4]:


# Answer-format suffix appended to decision-making prompts: bare answer,
# one item per line, item names rather than ICD codes.  Note that the
# ``\\n`` escape puts a literal backslash-n (not a newline) into the text
# shown to the model.
DECISION_MAKING_PROMPT = """Note that you should directly output the answer without any other information. If there are several items in the prediction, please separate them by `\\n`. For all predicted items, please use the item name instead of the item code. Do not output the code like ICD10 or ICD9."""

# Answer-format suffix for non-decision-making (risk-prediction) prompts: exactly one candidate.
RISK_PREDICTION_PROMPT = """Note that you should directly output the answer without any other information. You can only choose one answer from the Candidate List."""

# Template listing allowed answers; filled via str.format(candidates=...).
CANDIDATE_PROMPT = """You should choose the item from the candidate list below. Candidate List: {candidates}."""


# In[5]:

if __name__ == "__main__":
    import multiprocessing as mp

    # Force the "spawn" start method for any worker processes created below.
    mp.set_start_method("spawn", force=True)

    # ---- Model ------------------------------------------------------------
    # NOTE(review): gpu_memory_utilization and the context length are
    # hard-coded here; args.gpu_memory_utilization (0.8) is not used.
    model_name = "EHR-R1-1.7B"
    llm_engine_kwargs = dict(
        model=model_name,
        tensor_parallel_size=tp,
        trust_remote_code=True,
        max_model_len=32000,
        max_seq_len_to_capture=32000,
        gpu_memory_utilization=0.7,
    )
    model = LLM(**llm_engine_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_wrapped = VLLM_Model(model)

    # ---- Data -------------------------------------------------------------
    # Stream the decision-making split of EHR-Bench from a local JSONL file.
    from datasets import load_dataset

    bench_files = {"decision_making": "EHR-Bench/ehr_bench_decision_making.jsonl"}
    ds = load_dataset("json", data_files=bench_files, streaming=True)




    def format_mimic_data(data, prompt=True, think_prompt=True, direct_answer=True, oracle_reasoning=False):
        """Turn one raw EHR-Bench record into chat messages plus task metadata.

        Args:
            data: Raw record with ``"input"``, ``"instruction"``,
                ``"candidates"`` and ``"task_info"`` keys.  ``task_info`` has
                ``"task_type"`` and optionally a stored ``"reasoning"`` trace.
            prompt: Append the candidate list (when non-empty) and the
                task-type specific answer-format instruction.
            think_prompt: Append a ``/no_think`` or ``/think`` switch to the
                user turn.
            direct_answer: Choose ``/no_think`` (True) or ``/think`` (False).
            oracle_reasoning: If the record carries a parsable stored trace,
                prefill it as an open ``<think>`` assistant turn.

        Returns:
            ``{"prompt": [messages...], "task_info": {...}}``.  ``task_info``
            is a deep copy of the record's, with ``"candidates"`` added.
            When oracle prefilling succeeds, the record's own
            ``task_info["reasoning"]`` is overwritten in place with the
            trimmed trace before copying (unchanged behaviour).
        """
        import ast

        ehr_input = data["input"]
        instruction = data["instruction"]
        candidates = data["candidates"]
        task_info = data["task_info"]

        if prompt:
            if task_info["task_type"] == "decision_making":
                format_hint = DECISION_MAKING_PROMPT
            else:
                format_hint = RISK_PREDICTION_PROMPT
            parts = [ehr_input, instruction]
            if candidates:
                parts.append(CANDIDATE_PROMPT.format(candidates=candidates))
            parts.append(format_hint)
            input_prompt = "\n".join(parts)
        else:
            input_prompt = ehr_input + "\n" + instruction

        if think_prompt:
            input_prompt += "\n/no_think" if direct_answer else "\n/think"

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": input_prompt},
        ]

        if oracle_reasoning and task_info.get("reasoning", None):
            try:
                # ast.literal_eval accepts the same dict/list/str literals the
                # previous eval() did, but cannot execute code embedded in a
                # dataset file.
                parsed = ast.literal_eval(task_info["reasoning"])
                reasoning = parsed["choices"][0]["message"]["content"]
                # Cut the trace before its final answer so the model must produce it.
                reasoning_wo_answer = reasoning.rsplit("Final Results", 1)[0]
            except (ValueError, SyntaxError, TypeError, KeyError, IndexError, AttributeError, RecursionError):
                # Malformed or unexpected trace: fall back to the plain prompt.
                pass
            else:
                messages.append({"role": "assistant", "content": f"<think>{reasoning_wo_answer}"})
                task_info["reasoning"] = reasoning_wo_answer

        clean_data = {"prompt": messages}
        clean_data["task_info"] = copy.deepcopy(task_info)
        clean_data["task_info"]["candidates"] = candidates
        return clean_data

    def get_data_chunk(sample_infos, chunk_num, chunk_idx):
        """Return the ``chunk_idx``-th of ``chunk_num`` contiguous, balanced slices.

        Every chunk gets ``len // chunk_num`` or ``len // chunk_num + 1``
        items, so no chunk is empty while ``len >= chunk_num``.  The previous
        ``len // chunk_num + 1`` chunk size could overload the leading chunks
        and leave trailing ones empty (e.g. 10 items in 5 chunks -> 3,3,3,1,0).

        Args:
            sample_infos: Any sliceable sequence.
            chunk_num: Number of chunks (>= 1).
            chunk_idx: Which chunk to return, ``0 <= chunk_idx < chunk_num``.

        Raises:
            ValueError: if ``chunk_num`` or ``chunk_idx`` is out of range.
        """
        if chunk_num < 1:
            raise ValueError(f"chunk_num must be >= 1, got {chunk_num}")
        if not 0 <= chunk_idx < chunk_num:
            raise ValueError(f"chunk_idx must be in [0, {chunk_num}), got {chunk_idx}")
        base, remainder = divmod(len(sample_infos), chunk_num)
        # The first `remainder` chunks take one extra item each.
        start = chunk_idx * base + min(chunk_idx, remainder)
        stop = start + base + (1 if chunk_idx < remainder else 0)
        return sample_infos[start:stop]


    # In[8]:


    task = "decision_making"

    def direct_answer_infer(args, model, examples):
        """Run single-pass (no thinking) inference for one formatted example.

        Args:
            args: Run configuration.  Reads ``sample_num``, ``temperature``,
                ``top_p``, ``top_k`` and, if present,
                ``direct_answer_max_new_tokens`` (default 128).
            model: Callable with the ``VLLM_Model.__call__`` signature.
            examples: Formatted examples (see ``format_mimic_data``).  Exactly
                one is supported, because its task type selects the decoding setup.

        Returns:
            ``model``'s per-example lists of sampled outputs.

        Raises:
            ValueError: if ``examples`` does not contain exactly one item.
        """
        if len(examples) != 1:
            raise ValueError(f"direct_answer_infer supports exactly one example per call, got {len(examples)}")
        inputs = [example["prompt"] for example in examples]

        if examples[0]["task_info"]["task_type"] == "risk_prediction":
            # Yes/no tasks: a single biased token is enough.
            logit_bias_words = ["yes", "no"]
            max_new_tokens = 1
        else:
            logit_bias_words = None
            # NOTE(review): 128 tokens may truncate long newline-separated item
            # lists and lower recall/F1 on multi-item tasks.  Set
            # ``args.direct_answer_max_new_tokens`` higher to check.
            max_new_tokens = getattr(args, "direct_answer_max_new_tokens", 128)

        infer_args = {
            "n": args.sample_num,
            "max_tokens": max_new_tokens,
            "temperature": args.temperature,
            "top_p": args.top_p,
            "top_k": args.top_k,
        }

        outputs = model(
            inputs,
            infer_args=infer_args,
            logit_bias_words=logit_bias_words,
            enable_thinking=False,
        )
        # Shallow-copy each per-example list of samples (no flattening happens).
        return [list(output) for output in outputs]

    def reasoning_infer(args, model, examples):
        """Run inference with thinking enabled and separate the reasoning trace.

        Risk-prediction examples use one yes/no-biased token with thinking off.
        Other examples enable thinking and allow up to ``args.max_new_tokens``
        tokens.  Each sample's ``"trajectory"`` is then split into a
        ``"reasoning"`` trace and the final answer text when markers are found.

        Args:
            args: Run configuration.  Reads ``sample_num``, ``temperature``,
                ``top_p``, ``top_k``, ``max_new_tokens`` and, if present,
                ``model_name_or_path``.
            model: Callable with the ``VLLM_Model.__call__`` signature.
            examples: Formatted examples (see ``format_mimic_data``).  The
                first example's task type selects the decoding setup.

        Returns:
            Per-example lists of sample dicts, with ``"reasoning"`` added where
            a trace could be separated.
        """
        inputs = [example["prompt"] for example in examples]

        if examples[0]["task_info"]["task_type"] == "risk_prediction":
            logit_bias_words = ["yes", "no"]
            max_new_tokens = 1
            enable_thinking = False
        else:
            logit_bias_words = None
            max_new_tokens = args.max_new_tokens
            enable_thinking = True

        infer_args = {
            "n": args.sample_num,
            "max_tokens": max_new_tokens,
            "temperature": args.temperature,
            "top_p": args.top_p,
            "top_k": args.top_k,
        }

        outputs = model(
            inputs,
            infer_args=infer_args,
            logit_bias_words=logit_bias_words,
            enable_thinking=enable_thinking,
        )

        def split_reasoning_trajectory(model_name, output):
            """Move the thinking trace out of ``output["trajectory"]`` into ``output["reasoning"]``.

            Returns ``output`` unchanged when it already has ``"reasoning"``
            or when no recognised markers are present.
            """
            if "reasoning" in output:
                return output
            text = output["trajectory"]
            name = model_name.lower()

            if "gpt_oss" in name or "gpt-oss" in name:
                final = re.findall(r"<\|channel\|>final<\|message\|>(.*?)<\|return\|>", text, re.DOTALL)
                analysis = re.findall(r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>", text, re.DOTALL)
                if final and analysis:
                    output["trajectory"] = final[0].strip()
                    output["reasoning"] = analysis[0]
                return output

            # Split on the last "</think>" for every other model.  Previously
            # this only ran when the name contained "qwen"/"m2", so with
            # "EHR-R1-1.7B" the thinking text stayed in the scored answer.
            # Presumably EHR-R1 closes its trace with </think>, like the qwen
            # branch assumed — confirm.
            if "</think>" in text:
                reasoning, _, trajectory = text.rpartition("</think>")
                output["trajectory"] = trajectory.strip()
                output["reasoning"] = reasoning + "</think>"
            return output

        # `model_name_or_path` was missing from the default args, which raised AttributeError.
        model_name = getattr(args, "model_name_or_path", None) or ""
        return [[split_reasoning_trajectory(model_name, sample) for sample in output] for output in outputs]


    # In[11]:


    def iter_batches(iterable, batch_size):
        """Yield consecutive lists of ``batch_size`` items; the last list may be shorter.

        A list is emitted only when its length equals ``batch_size`` exactly;
        any leftover items are yielded once iteration ends.
        """
        pending = []
        for element in iterable:
            pending.append(element)
            if len(pending) != batch_size:
                continue
            yield pending
            pending = []
        if pending:
            yield pending


    # In[12]:


    def get_uid(example):
        """Return a stable identifier for ``example``: its ``task_info`` as key-sorted JSON."""
        info = example["task_info"]
        return json.dumps(info, sort_keys=True, ensure_ascii=False)


    cache_path = f"./cache/ehrbench_{task}.jsonl"
    # The cache directory may not exist on a fresh checkout.  Without this,
    # opening the file for appending below raises FileNotFoundError.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)

    # Resume support: collect the uids already written to the cache.
    # NOTE(review): this always resumes; args.resume is not consulted.
    done_uids = set()
    if os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    # Skip malformed lines (e.g. one truncated by an interrupted run).
                    continue
                done_uids.add(obj.get("uid"))

    print(f"Resume: found {len(done_uids)} completed examples in cache.")

    import gc
    import torch
    torch.cuda.empty_cache()

    # `with` guarantees the cache file is flushed and closed even when
    # inference or scoring raises mid-run.
    with open(cache_path, "a", encoding="utf-8") as cache_fp:
        for batch_id, raw_batch in enumerate(
            tqdm(iter_batches(ds[task], batch), desc=f"Inference with k={sample_num}")
        ):
            # Format prompts with the same thinking mode as the inference path
            # chosen below: direct_answer gives "/no_think", otherwise "/think".
            # NOTE(review): args.prompt is not forwarded, so the answer-format
            # suffixes are always appended.  Confirm which setting the paper used.
            examples_all = [
                format_mimic_data(
                    ex,
                    think_prompt=args.think_prompt,
                    direct_answer=args.direct_answer,
                    oracle_reasoning=args.oracle_reasoning,
                )
                for ex in raw_batch
            ]

            # Keep only examples not already in the cache.
            examples = []
            for ex in examples_all:
                uid = get_uid(ex)
                if uid not in done_uids:
                    ex["_uid"] = uid
                    examples.append(ex)

            if not examples:
                continue

            start_time = time.time()
            if args.direct_answer:
                trajectories = direct_answer_infer(args, model_wrapped, examples)
            else:
                trajectories = reasoning_infer(args, model_wrapped, examples)
            end_time = time.time()

            example_logs = [
                {
                    "uid": ex["_uid"],
                    "task_info": ex["task_info"],
                    "outputs": trajectories[i],
                    "meta": {
                        "task": task,
                        "batch_id": batch_id,
                        "inference_sec": end_time - start_time,
                    },
                }
                for i, ex in enumerate(examples)
            ]

            example_logs = score_func(example_logs)

            for row in example_logs:
                cache_fp.write(json.dumps(row, ensure_ascii=False) + "\n")
                done_uids.add(row["uid"])
            # Flush after every batch so an interrupted run can resume from here.
            cache_fp.flush()
            torch.cuda.empty_cache()
            del examples, trajectories, example_logs
            gc.collect()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions