From eb9e55b31b31696759e646a221542b7179fc47f9 Mon Sep 17 00:00:00 2001
From: Hu Chong <45162720+1854039@users.noreply.github.com>
Date: Sat, 12 Jul 2025 17:33:22 +0800
Subject: [PATCH] Add DocRED-based NER script

---
 README.md                  |   9 ++
 experiments/docred_ner.py  | 177 ++++++++++++++++++++++++++++++++++++
 experiments/wikiann_ner.py | 152 ++++++++++++++++++++++++++++++
 3 files changed, 338 insertions(+)
 create mode 100644 experiments/docred_ner.py
 create mode 100644 experiments/wikiann_ner.py

diff --git a/README.md b/README.md
index b1a67f0..b55f6ee 100644
--- a/README.md
+++ b/README.md
@@ -26,3 +26,12 @@ An example evaluation script is provided under `experiments/evaluate_entity_extr
 ```bash
 python experiments/evaluate_entity_extraction.py --dataset tacred --split validation --model
 ```
+
+For document-level entity annotation with the DocRED dataset using an OpenAI-compatible API, you can run:
+
+```bash
+python experiments/docred_ner.py \
+  --models gpt-3.5-turbo,gpt-4o \
+  --split validation --limit 100 \
+  --base_url "$OPENAI_BASE_URL" --api_key "$OPENAI_API_KEY"
+```
diff --git a/experiments/docred_ner.py b/experiments/docred_ner.py
new file mode 100644
index 0000000..113b7df
--- /dev/null
+++ b/experiments/docred_ner.py
@@ -0,0 +1,177 @@
+import argparse
+import json
+from itertools import accumulate
+
+from datasets import load_dataset
+from openai import OpenAI
+from sklearn.metrics import precision_recall_fscore_support
+
+from nanographrag_tmp._utils import check_and_fix_json
+
+
+# Label space shared by the prompt and the metric. Only "O" is fixed up front;
+# prepare_docred() appends a B-/I- pair for every entity type found in the data.
+LABELS = ["O"]
+
+
+def prepare_docred(split: str, limit: int = 0):
+    """Load a DocRED split and add flat ``tokens`` and BIO ``labels`` columns.
+
+    Args:
+        split: Dataset split name passed to ``load_dataset``.
+        limit: Keep only the first ``limit`` documents; ``0`` keeps them all.
+
+    Returns:
+        The dataset with ``tokens`` and ``labels`` added to every document.
+        As a side effect, ``LABELS`` gains the B-/I- tags that occur in it.
+    """
+    ds = load_dataset("thunlp/docred", split=split)
+    if limit > 0:
+        # Subsample before mapping so only the requested documents are
+        # converted, and clamp so an oversized limit cannot raise IndexError.
+        ds = ds.select(range(min(limit, len(ds))))
+
+    def convert(item):
+        sents = item["sents"]
+        tokens = [token for sent in sents for token in sent]
+        # offsets[i] is the document-level index of sentence i's first token.
+        offsets = [0, *accumulate(len(sent) for sent in sents)]
+        labels = ["O"] * len(tokens)
+        for entity in item["vertexSet"]:
+            for mention in entity:
+                base = offsets[mention["sent_id"]]
+                start, end = mention["pos"]  # sentence-relative, end-exclusive
+                ent_type = mention.get("type", "MISC")
+                labels[base + start] = f"B-{ent_type}"
+                for i in range(base + start + 1, base + end):
+                    labels[i] = f"I-{ent_type}"
+        return {"tokens": tokens, "labels": labels}
+
+    ds = ds.map(convert)
+    # Collect the label space from the converted rows, not from inside
+    # ``convert``: ``map`` can reuse a cached result or run in worker processes,
+    # and in both cases a module global mutated by the callback stays empty.
+    ent_types = sorted(
+        {tag[2:] for doc_tags in ds["labels"] for tag in doc_tags if tag != "O"}
+    )
+    for ent_type in ent_types:
+        if f"B-{ent_type}" not in LABELS:
+            LABELS.extend([f"B-{ent_type}", f"I-{ent_type}"])
+    return ds
+
+
+def ner_call(
+    client: OpenAI, model: str, tokens: list[str], max_tokens: int = 0
+) -> list[str]:
+    """Ask ``model`` for one label per token and return the parsed labels.
+
+    Args:
+        client: OpenAI-compatible client used for the chat completion.
+        model: Model name forwarded to the API.
+        tokens: Tokens to label, in order.
+        max_tokens: Completion budget; ``0`` derives it from ``len(tokens)``.
+
+    Returns:
+        The ``labels`` list from the model's JSON reply, coerced to strings.
+
+    Raises:
+        ValueError: If the ``labels`` entry of the reply is not a list.
+    """
+    if max_tokens <= 0:
+        # The reply carries one quoted label per input token, so a fixed
+        # 200-token cap truncates the JSON for all but very short inputs.
+        max_tokens = max(200, 8 * len(tokens) + 32)
+    label_space = ", ".join(LABELS)
+    messages = [
+        {"role": "system", "content": "You are a named entity recognition model."},
+        {
+            "role": "user",
+            "content": (
+                f"Tokens: {tokens}\n"
+                f"Provide a label from [{label_space}] for each token in order.\n"
+                f"Return JSON as {{\"labels\": []}}"
+            ),
+        },
+    ]
+    completion = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        max_tokens=max_tokens,
+    )
+    response = completion.choices[0].message.content
+    fixed = check_and_fix_json(response)
+    labels = json.loads(fixed)["labels"]
+    if not isinstance(labels, list):
+        raise ValueError(f"Expected a list of labels, got {type(labels).__name__}")
+    return [str(label) for label in labels]
+
+
+def evaluate(ds, client: OpenAI, model: str, context_size: int):
+    """Score ``model`` on ``ds`` and return micro precision, recall and F1.
+
+    Args:
+        ds: Documents with ``tokens`` and ``labels`` (see ``prepare_docred``).
+        client: OpenAI-compatible client.
+        model: Model name forwarded to the API.
+        context_size: Keep only this many leading tokens; ``<= 0`` keeps all.
+
+    Returns:
+        ``(precision, recall, f1)`` micro-averaged over the entity labels
+        ("O" excluded); all zeros when no document could be scored.
+    """
+    y_true = []
+    y_pred = []
+    total = 0
+    skipped = 0
+    for item in ds:
+        total += 1
+        tokens = item["tokens"][:context_size] if context_size > 0 else item["tokens"]
+        gold = item["labels"][:context_size] if context_size > 0 else item["labels"]
+        try:
+            pred = ner_call(client, model, tokens)
+        except Exception as e:
+            print(f"Error calling model: {e}")
+            skipped += 1
+            continue
+        if len(pred) != len(gold):
+            # Predictions cannot be aligned with the gold labels.
+            skipped += 1
+            continue
+        y_true.extend(gold)
+        y_pred.extend(pred)
+    if skipped:
+        # Surface dropped documents; otherwise the scores look better founded
+        # than they are.
+        print(f"Skipped {skipped} of {total} documents for model {model}")
+    if not y_true:
+        return 0.0, 0.0, 0.0
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        y_true, y_pred, labels=LABELS[1:], average="micro"
+    )
+    return precision, recall, f1
+
+
+def main():
+    """Parse the CLI arguments and print metrics per model and context size."""
+    parser = argparse.ArgumentParser(description="Evaluate NER on DocRED")
+    parser.add_argument("--models", required=True, help="Comma separated model names")
+    parser.add_argument("--base_url", default=None)
+    parser.add_argument("--api_key", default=None)
+    parser.add_argument("--split", default="validation")
+    parser.add_argument("--limit", type=int, default=0)
+    parser.add_argument("--context_sizes", default="0", help="Comma separated token counts")
+    args = parser.parse_args()
+
+    ds = prepare_docred(args.split, args.limit)
+    client = OpenAI(api_key=args.api_key, base_url=args.base_url)
+    # Tolerate spaces and stray commas in the comma separated options.
+    context_sizes = [int(x) for x in args.context_sizes.split(",") if x.strip()]
+    models = [m.strip() for m in args.models.split(",") if m.strip()]
+    for model in models:
+        for c in context_sizes:
+            p, r, f = evaluate(ds, client, model, c)
+            print(f"Model={model} context={c} precision={p:.3f} recall={r:.3f} f1={f:.3f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/experiments/wikiann_ner.py b/experiments/wikiann_ner.py
new file mode 100644
index 0000000..b1ae2d2
--- /dev/null
+++ b/experiments/wikiann_ner.py
@@ -0,0 +1,152 @@
+import argparse
+import json
+
+from datasets import load_dataset
+from openai import OpenAI
+from sklearn.metrics import precision_recall_fscore_support
+
+from nanographrag_tmp._utils import check_and_fix_json
+
+LABELS = [
+    "O",
+    "B-PER",
+    "I-PER",
+    "B-ORG",
+    "I-ORG",
+    "B-LOC",
+    "I-LOC",
+]
+
+
+def ner_call(
+    client: OpenAI, model: str, tokens: list[str], max_tokens: int = 0
+) -> list[str]:
+    """Call the model to label tokens.
+
+    The model should return JSON: {"labels": ["..."]} with a label for each token.
+
+    Args:
+        client: OpenAI-compatible client used for the chat completion.
+        model: Model name forwarded to the API.
+        tokens: Tokens to label, in order.
+        max_tokens: Completion budget; ``0`` derives it from ``len(tokens)``.
+
+    Returns:
+        The ``labels`` list from the model's JSON reply, coerced to strings.
+
+    Raises:
+        ValueError: If the ``labels`` entry of the reply is not a list.
+    """
+    if max_tokens <= 0:
+        # The reply carries one quoted label per input token, so a fixed
+        # 200-token cap truncates the JSON for longer sentences.
+        max_tokens = max(200, 8 * len(tokens) + 32)
+    label_space = ", ".join(LABELS)
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a named entity recognition model."
+        },
+        {
+            "role": "user",
+            "content": (
+                f"Tokens: {tokens}\n"
+                f"Provide a label from [{label_space}] for each token in order.\n"
+                f"Return JSON as {{\"labels\": []}}"
+            ),
+        },
+    ]
+    completion = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        max_tokens=max_tokens,
+    )
+    response = completion.choices[0].message.content
+    fixed = check_and_fix_json(response)
+    labels = json.loads(fixed)["labels"]
+    if not isinstance(labels, list):
+        raise ValueError(f"Expected a list of labels, got {type(labels).__name__}")
+    return [str(label) for label in labels]
+
+
+def evaluate(ds, label_list, client: OpenAI, model: str, context_size: int):
+    """Score ``model`` on ``ds`` and return micro precision, recall and F1.
+
+    Args:
+        ds: Examples with ``tokens`` and integer ``ner_tags``.
+        label_list: Tag names indexed by the ``ner_tags`` ids.
+        client: OpenAI-compatible client.
+        model: Model name forwarded to the API.
+        context_size: Keep only this many leading tokens; ``<= 0`` keeps all.
+
+    Returns:
+        ``(precision, recall, f1)`` micro-averaged over the entity labels
+        ("O" excluded); all zeros when no example could be scored.
+    """
+    y_true = []
+    y_pred = []
+    total = 0
+    skipped = 0
+    for item in ds:
+        total += 1
+        tokens = item["tokens"][:context_size] if context_size > 0 else item["tokens"]
+        labels = item["ner_tags"][:context_size] if context_size > 0 else item["ner_tags"]
+        gold = [label_list[t] for t in labels]
+        try:
+            pred = ner_call(client, model, tokens)
+        except Exception as e:
+            print(f"Error calling model: {e}")
+            skipped += 1
+            continue
+        if len(pred) != len(gold):
+            # Predictions cannot be aligned with the gold labels.
+            skipped += 1
+            continue
+        y_true.extend(gold)
+        y_pred.extend(pred)
+    if skipped:
+        # Surface dropped examples; otherwise the scores look better founded
+        # than they are.
+        print(f"Skipped {skipped} of {total} examples for model {model}")
+    if not y_true:
+        return 0.0, 0.0, 0.0
+    # Leave "O" out of the average: with every class included, micro
+    # precision, recall and F1 all reduce to plain token accuracy.
+    entity_labels = [label for label in LABELS if label != "O"]
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        y_true, y_pred, labels=entity_labels, average="micro"
+    )
+    return precision, recall, f1
+
+
+def main():
+    """Parse the CLI arguments and print metrics per model and context size."""
+    parser = argparse.ArgumentParser(description="Evaluate NER on WikiANN")
+    parser.add_argument("--models", required=True, help="Comma separated model names")
+    parser.add_argument("--base_url", default=None)
+    parser.add_argument("--api_key", default=None)
+    parser.add_argument("--lang", default="en")
+    parser.add_argument("--split", default="validation")
+    parser.add_argument("--limit", type=int, default=0)
+    parser.add_argument("--context_sizes", default="0", help="Comma separated token counts")
+    args = parser.parse_args()
+
+    ds = load_dataset("wikiann", args.lang, split=args.split)
+    if args.limit > 0:
+        # Clamp so a limit larger than the split cannot raise IndexError.
+        ds = ds.select(range(min(args.limit, len(ds))))
+    label_list = ds.features["ner_tags"].feature.names
+
+    client = OpenAI(api_key=args.api_key, base_url=args.base_url)
+
+    # Tolerate spaces and stray commas in the comma separated options.
+    context_sizes = [int(x) for x in args.context_sizes.split(",") if x.strip()]
+    models = [m.strip() for m in args.models.split(",") if m.strip()]
+    for model in models:
+        for c in context_sizes:
+            p, r, f = evaluate(ds, label_list, client, model, c)
+            print(f"Model={model} context={c} precision={p:.3f} recall={r:.3f} f1={f:.3f}")
+
+
+if __name__ == "__main__":
+    main()