5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -129,4 +129,7 @@ dmypy.json
.pyre/

# mac
.DS_Store
.DS_Store

# dw test
/test
3 changes: 2 additions & 1 deletion config.yml
@@ -12,7 +12,7 @@ mlp_hidden_size: 150
max_span_length: 10
dropout: 0.4
logit_dropout: 0.2
bert_model_name: bert-base-uncased
bert_model_name: /mnt/data1/public/pretrain/bert-base-uncased
bert_output_size: 0
bert_dropout: 0.0
separate_threshold: 1.4
@@ -39,3 +39,4 @@ logging_steps: 32
validate_every: 20000
device: -1
log_file: train.log
eval_type: train # debug
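The new `eval_type` key is only explained by the code in `entity_relation_joint_decoder.py`: the value `debug` routes `dev()`/`test()` through the legacy UniRE evaluator, while any other value uses the new one. A hedged sketch of the knob (the comments are inferred from this diff, not from project documentation):

```yaml
# eval_type selects the evaluation path taken by dev()/test():
#   debug      -> legacy UniRE eval (eval/eval_old.py, eval/prediction_outputs_old.py)
#   any other  -> new stand-alone eval (eval/eval.py); this PR sets "train"
eval_type: train
```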
6 changes: 3 additions & 3 deletions data/entity-relation/ACE2004/ace2004.sh
@@ -7,6 +7,6 @@ python ../transfer.py transfer fold1/00-raw/train.json fold1/01-change-fields/tr
python ../transfer.py transfer fold1/00-raw/dev.json fold1/01-change-fields/dev.json [PER-SOC]
python ../transfer.py transfer fold1/00-raw/test.json fold1/01-change-fields/test.json [PER-SOC]

python ../process.py process fold1/01-change-fields/train.json fold1/ent_rel_file.json fold1/02-matrix/train.json bert-base-uncased 200
python ../process.py process fold1/01-change-fields/dev.json fold1/ent_rel_file.json fold1/02-matrix/dev.json bert-base-uncased 200
python ../process.py process fold1/01-change-fields/test.json fold1/ent_rel_file.json fold1/02-matrix/test.json bert-base-uncased 200
python ../process.py process fold1/01-change-fields/train.json fold1/ent_rel_file.json fold1/02-matrix/train.json bert-base-uncased 200 True
python ../process.py process fold1/01-change-fields/dev.json fold1/ent_rel_file.json fold1/02-matrix/dev.json bert-base-uncased 200 True
python ../process.py process fold1/01-change-fields/test.json fold1/ent_rel_file.json fold1/02-matrix/test.json bert-base-uncased 200 True
8 changes: 4 additions & 4 deletions data/entity-relation/ACE2005/ace2005.sh
@@ -7,8 +7,8 @@ python ../transfer.py transfer 00-raw/train.json 01-change-fields/train.json [PE
python ../transfer.py transfer 00-raw/dev.json 01-change-fields/dev.json [PER-SOC]
python ../transfer.py transfer 00-raw/test.json 01-change-fields/test.json [PER-SOC]

python ../process.py process 01-change-fields/train.json ent_rel_file.json 02-matrix/train.json bert-base-uncased 200-raw
python ../process.py process 01-change-fields/dev.json ent_rel_file.json 02-matrix/dev.json bert-base-uncased 200-raw
python ../process.py process 01-change-fields/test.json ent_rel_file.json 02-matrix/test.json bert-base-uncased 200-raw
python ../process.py process 01-change-fields/train.json ent_rel_file.json 02-matrix/train.json /mnt/data1/public/pretrain/bert-base-uncased 200 True
python ../process.py process 01-change-fields/dev.json ent_rel_file.json 02-matrix/dev.json /mnt/data1/public/pretrain/bert-base-uncased 200 True
python ../process.py process 01-change-fields/test.json ent_rel_file.json 02-matrix/test.json /mnt/data1/public/pretrain/bert-base-uncased 200 True

#rm -rf 01-change-fields
#rm -rf 01-change-fields
6 changes: 3 additions & 3 deletions data/entity-relation/SciERC/scierc.sh
@@ -7,6 +7,6 @@ python ../transfer.py transfer 00-raw/train.json 01-change-fields/train.json [CO
python ../transfer.py transfer 00-raw/dev.json 01-change-fields/dev.json [COMPARE,CONJUNCTION]
python ../transfer.py transfer 00-raw/test.json 01-change-fields/test.json [COMPARE,CONJUNCTION]

python ../process.py process 01-change-fields/train.json ent_rel_file.json 02-matrix/train.json allenai/scibert_scivocab_uncased 200-raw
python ../process.py process 01-change-fields/dev.json ent_rel_file.json 02-matrix/dev.json allenai/scibert_scivocab_uncased 200-raw
python ../process.py process 01-change-fields/test.json ent_rel_file.json 02-matrix/test.json allenai/scibert_scivocab_uncased 200-raw
python ../process.py process 01-change-fields/train.json ent_rel_file.json 02-matrix/train.json allenai/scibert_scivocab_uncased 200 True
python ../process.py process 01-change-fields/dev.json ent_rel_file.json 02-matrix/dev.json allenai/scibert_scivocab_uncased 200 True
python ../process.py process 01-change-fields/test.json ent_rel_file.json 02-matrix/test.json allenai/scibert_scivocab_uncased 200 True
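Each `process.py` invocation above now passes a sixth positional token, `True`. Since `process.py` dispatches through python-fire, that token binds to the renamed `cross_sentence` parameter (formerly `standard`, now defaulting to `False`), enabling the article-grouping branch. A rough standard-library sketch of that binding (`bind` is a hypothetical stand-in for fire's parsing, not project code):

```python
import inspect

def process(source_file, ent_rel_file, target_file, pretrained_model,
            max_length=200, cross_sentence=False):
    """Signature as in data/entity-relation/process.py after this PR."""
    return max_length, cross_sentence

def bind(argv):
    """Hypothetical mimic of fire: bind positional tokens left-to-right,
    coercing "True"/"False" and integer literals like fire does."""
    params = list(inspect.signature(process).parameters)
    kwargs = {}
    for name, tok in zip(params, argv):
        kwargs[name] = {"True": True, "False": False}.get(tok, tok)
        if tok.isdigit():
            kwargs[name] = int(tok)
    return process(**kwargs)

# mirrors: python ../process.py process ... scibert_scivocab_uncased 200 True
assert bind(["a.json", "er.json", "b.json", "scibert", "200", "True"]) == (200, True)
```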
44 changes: 15 additions & 29 deletions data/entity-relation/process.py
@@ -1,3 +1,4 @@
import os
import json
import fire

@@ -93,18 +94,12 @@ def add_joint_label(sent, ent_rel_info):
ent2offset = {}
for ent in sent['entityMentions']:
ent2offset[ent['emId']] = ent['offset']
label_matrix[ent['offset'][0]: ent['offset'][1]][ent['offset'][0]: ent['offset'][1]] = ent_rel_id[ent['label']]

label_matrix[ent['offset'][0]: ent['offset'][1]][:, ent['offset'][0]: ent['offset'][1]] = ent_rel_id[ent['label']]
for rel in sent['relationMentions']:
label_matrix[ent2offset[rel['em1Id']][0]: ent2offset[rel['em1Id']][1]][ent2offset[rel['em2Id']][0]: ent2offset[rel['em2Id']][1]] = ent_rel_id[rel['label']]
label_matrix[ent2offset[rel['em1Id']][0]: ent2offset[rel['em1Id']][1]][:, ent2offset[rel['em2Id']][0]: ent2offset[rel['em2Id']][1]] = ent_rel_id[rel['label']]
if ent_rel_id[rel['label']] in ent_rel_info['symmetric']:
label_matrix[ent2offset[ent2offset[rel['em2Id']][0]: ent2offset[rel['em2Id']][1]][rel['em1Id']][0]: ent2offset[rel['em1Id']][1]] = ent_rel_id[rel['label']]

# for i in range(ent2offset[rel['em1Id']][0], ent2offset[rel['em1Id']][1]):
# for j in range(ent2offset[rel['em2Id']][0], ent2offset[rel['em2Id']][1]):
# label_matrix[i][j] = ent_rel_id[rel['label']]
# if ent_rel_id[rel['label']] in ent_rel_info['symmetric']:
# label_matrix[j][i] = ent_rel_id[rel['label']]
label_matrix[ent2offset[rel['em2Id']][0]: ent2offset[rel['em2Id']][1]][:, ent2offset[rel['em1Id']][0]: ent2offset[rel['em1Id']][1]] = ent_rel_id[rel['label']]

sent['jointLabelMatrix'] = label_matrix.tolist()

@@ -130,17 +125,10 @@ def add_joint_label_with_BItag(sent, ent_rel_info):
ent['label'] = "B-" + ent['label']

for rel in sent['relationMentions']:
# for i in range(ent2offset[rel['em1Id']][0], ent2offset[rel['em1Id']][1]):
# for j in range(ent2offset[rel['em2Id']][0], ent2offset[rel['em2Id']][1]):
# #assert label_matrix[i][j] == 0, "Exist relation overlapping!"
# label_matrix[i][j] = ent_rel_info['id'][rel['label']]
# if ent_rel_info['id'][rel['label']] in ent_rel_info['symmetric']:
# label_matrix[j][i] = ent_rel_info['id'][rel['label']]

label_matrix[ent2offset[rel['em1Id']][0]: ent2offset[rel['em1Id']][1]][ent2offset[rel['em2Id']][0]: ent2offset[rel['em2Id']][1]] = ent_rel_info['id'][rel['label']]
label_matrix[ent2offset[rel['em1Id']][0]: ent2offset[rel['em1Id']][1]][:, ent2offset[rel['em2Id']][0]: ent2offset[rel['em2Id']][1]] = ent_rel_info['id'][rel['label']]
if ent_rel_info['id'][rel['label']] in ent_rel_info['symmetric']:
label_matrix[ent2offset[ent2offset[rel['em2Id']][0]: ent2offset[rel['em2Id']][1]][rel['em1Id']][0]: ent2offset[rel['em1Id']][1]] = ent_rel_info['id'][rel['label']]

label_matrix[ent2offset[rel['em2Id']][0]: ent2offset[rel['em2Id']][1]][:, ent2offset[rel['em1Id']][0]: ent2offset[rel['em1Id']][1]] = ent_rel_info['id'][rel['label']]
sent['jointLabelMatrix'] = label_matrix.tolist()

def add_wordpiece_fields(sent, tokenizer):
@@ -167,7 +155,7 @@ def add_wordpiece_fields(sent, tokenizer):
wordpiece_segment_ids = [0] * len(wordpiece_tokens)
assert len(wordpiece_tokens) == len(wordpiece_segment_ids)

sent['wordpiece_tokens'] = wordpiece_tokens
sent['wordpieceSentText'] = wordpiece_tokens
sent['wordpieceTokensIndex'] = wordpiece_tokens_index
sent['wordpieceSegmentIds'] = wordpiece_segment_ids
return sent
@@ -248,7 +236,7 @@ def get_ent_rel_file(ent_rel_file, data_file_path, data_parts=['train.json', 'de

print(json.dumps(ins, ensure_ascii=False), file=fout)

def process(source_file, ent_rel_file, target_file, pretrained_model, max_length=200, standard=True):
def process(source_file, ent_rel_file, target_file, pretrained_model, max_length=200, cross_sentence=False):
auto_tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
print("Load {} tokenizer successfully.".format(pretrained_model))

@@ -259,14 +247,12 @@ def process(source_file, ent_rel_file, target_file, pretrained_model, max_length
with open(ent_rel_file, 'r', encoding='utf-8') as f:
ent_rel_info = json.load(f)

if not os.path.exists(os.path.dirname(target_file)):
os.mkdir(os.path.dirname(target_file))

with open(source_file, 'r', encoding='utf-8') as fin, open(target_file, 'w', encoding='utf-8') as fout:
# given datasets should conform to the standard setting, such as ACE2005, SciERC
if standard:
if cross_sentence:
sentences = []
for line in fin:
for i, line in enumerate(fin):
print(f"Process Standard dataset Line{i + 1}")
sent = json.loads(line.strip())

if len(sentences) == 0 or sentences[0]['articleId'] == sent['articleId']:
@@ -284,11 +270,11 @@
# processing other datasets
else:
for i, line in enumerate(fin):
print(f"Process Line{i + 1}")
print(f"Process other dataset Line{i + 1}")
sent = json.loads(line.strip())

add_wordpiece_fields(sent, auto_tokenizer)
add_joint_label(sent, ent_rel_id)
add_joint_label(sent, ent_rel_info)

print(json.dumps(sent, ensure_ascii=False), file=fout)

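The slicing change in `add_joint_label` above is a real bug fix: `label_matrix[i0:i1][j0:j1]` chains two row slices, while `label_matrix[i0:i1][:, j0:j1]` selects rows and then columns. Because basic NumPy slicing returns views, assigning through the second form writes into the original matrix. A minimal sketch of the difference (toy indices, not the dataset's):

```python
import numpy as np

m = np.zeros((5, 5), dtype=int)

# Buggy pattern (the removed lines): both brackets index rows.
# rows 3:5 of a 2-row view is an empty slice, so nothing is written.
m[1:3][3:5] = 7
assert not m.any()

# Fixed pattern (the added lines): rows, then columns of the row view.
# Basic slicing yields a view, so the assignment writes through to m.
m[1:3][:, 3:5] = 7
assert (m[1:3, 3:5] == 7).all()
assert m.sum() == 7 * 4
```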
Empty file modified data/event/ChFinAnn/01-coref/coref_resolv.py
100755 → 100644
Empty file.
Empty file modified data/event/ChFinAnn/02-sample/sample.py
100755 → 100644
Empty file.
Empty file modified data/event/ChFinAnn/03-evt-rel/run.sh
100755 → 100644
Empty file.
3 changes: 2 additions & 1 deletion data/event/ChFinAnn/scripts/eval.py
100755 → 100644
@@ -143,6 +143,7 @@ def main():
cmd_parser.add_argument("gold_file", help="gold file")
cmd_parser.add_argument("-p", "--overlap_propotion", metavar="FLOAT",
help="a predicted span is correct if it overlaps with p of a gold span (default 1.0)",
type=float,
default=1.0)
cmd_parser.add_argument("-t", "--filter_type", metavar="TYPE",
help="ignore entities belong to those types during evaluation (default None)",
@@ -157,6 +158,6 @@ def main():

cmds = cmd_parser.parse_args()

logging.basicConfig(filename=os.path.basename(cmds.pred_file)+'_evt.res',
logging.basicConfig(filename=os.path.basename(cmds.pred_file)+'.evt.log',
encoding='utf-8', level=logging.DEBUG)
main()
Empty file modified data/event/ChFinAnn/scripts/evt2evt.py
100755 → 100644
Empty file.
Empty file modified data/event/ChFinAnn/scripts/evt2rel.py
100755 → 100644
Empty file.
Empty file modified data/event/ChFinAnn/scripts/prettyjson.py
100755 → 100644
Empty file.
Empty file modified data/event/ChFinAnn/scripts/rel2evt.py
100755 → 100644
Empty file.
41 changes: 30 additions & 11 deletions entity_relation_joint_decoder.py
@@ -19,9 +19,15 @@

from transformers import BertTokenizer, AutoTokenizer, AdamW, get_linear_schedule_with_warmup

# origin unire eval
from eval.prediction_outputs_old import print_predictions_for_joint_decoding as print_predictions_for_debug
from eval.eval_old import eval_file as eval_file_for_debug

# eval
from eval.prediction_outputs import print_predictions_for_joint_decoding
from eval.eval import eval_file

from utils.argparse import ConfigurationParer
from utils.prediction_outputs import print_predictions_for_joint_decoding
from utils.eval import eval_file
from inputs.vocabulary import Vocabulary
from inputs.fields.token_field import TokenField
from inputs.fields.raw_token_field import RawTokenField
@@ -145,7 +151,7 @@ def train(cfg, dataset, model):
best_f1 = dev_f1
logger.info("Save model...")
#torch.save(model.state_dict(), open(cfg.best_model_path, "wb"))
torch.save(model.state_dict())
torch.save(model.state_dict(), cfg.best_model_path)
elif last_epoch != epoch:
early_stop_cnt += 1
if early_stop_cnt > cfg.early_stop:
@@ -203,9 +209,17 @@ def dev(cfg, dataset, model):
all_outputs.extend(batch_outpus)
logger.info(f"Cost time: {cost_time}s")
dev_output_file = os.path.join(cfg.save_dir, "dev.output")
print_predictions_for_joint_decoding(all_outputs, dev_output_file, dataset.vocab)
eval_metrics = ['joint-label', 'separate-position', 'ent', 'exact-rel', 'overlap-rel']
joint_label_score, separate_position_score, ent_score, exact_rel_score, overlap_rel_score = eval_file(dev_output_file, eval_metrics)
dev_gold_file = os.path.join(cfg.save_dir, "dev_gold.output")

if cfg.eval_type == "debug":
# old unire eval
print_predictions_for_debug(all_outputs, dev_output_file, dataset.vocab)
eval_metrics = ['joint-label', 'separate-position', 'ent', 'exact-rel', 'overlap-rel']
joint_label_score, separate_position_score, ent_score, exact_rel_score, overlap_rel_score = eval_file_for_debug(dev_output_file, eval_metrics, cfg)
else:
print_predictions_for_joint_decoding(all_outputs, dev_output_file, dev_gold_file, dataset.vocab)
ent_score, exact_rel_score = eval_file(dev_output_file, dev_gold_file, entity_metric=["exact"], relation_metric=["exact"])

return ent_score + exact_rel_score


Expand All @@ -226,10 +240,16 @@ def test(cfg, dataset, model):
logger.info(f"Cost time: {cost_time}s")

test_output_file = os.path.join(cfg.save_dir, "test.output")
print_predictions_for_joint_decoding(all_outputs, test_output_file, dataset.vocab)
eval_metrics = ['joint-label', 'separate-position', 'ent', 'exact-rel', 'overlap-rel']
eval_file(test_output_file, eval_metrics)

test_gold_file = os.path.join(cfg.save_dir, "test_gold.output")

if cfg.eval_type == "debug":
# old unire eval
print_predictions_for_debug(all_outputs, test_output_file, dataset.vocab)
eval_metrics = ['joint-label', 'separate-position', 'ent', 'exact-rel', 'overlap-rel']
eval_file_for_debug(test_output_file, eval_metrics, cfg)
else:
print_predictions_for_joint_decoding(all_outputs, test_output_file, test_gold_file, dataset.vocab)
eval_file(test_output_file, test_gold_file, entity_metric=["exact"], relation_metric=["exact"])

def main():
# config settings
@@ -324,7 +344,6 @@ def main():
model = EntRelJointDecoder(cfg=cfg, vocab=vocab, ent_rel_file=ent_rel_file)

if cfg.test and os.path.exists(cfg.best_model_path):
#state_dict = torch.load(open(cfg.best_model_path, 'rb'), map_location=lambda storage, loc: storage)
state_dict = torch.load(cfg.best_model_path)
model.load_state_dict(state_dict)
logger.info("Loading best training model {} successfully for testing.".format(cfg.best_model_path))
64 changes: 64 additions & 0 deletions eval/README.md
@@ -0,0 +1,64 @@
# Evaluation


## eval.py

A stand-alone script for evaluating entity-relation extraction.

```shell
usage: eval.py [-h] [-p FLOAT] [-e [TYPE ...]] [-r [TYPE ...]] pred_file gold_file

a (stand alone) script for evaluating entity relation extraction

positional arguments:
pred_file predict file
gold_file gold file

options:
-h, --help show this help message and exit
-p FLOAT, --overlap_propotion FLOAT
a predicted span is correct if it overlaps with p of a gold span (default 1.0)
-e [TYPE ...], --entity_metrics [TYPE ...]
criteria for evaluating correctness of entities.
exact: accept entities with correct type and offset
offset: ignore entity type, only match entity offset
string: accept entities with correct string
overlap: accept entities with overlapped string (combine with -p option)
(default ALL)
-r [TYPE ...], --relation_metrics [TYPE ...]
criteria for evaluating correctness of relations.
exact: accept relations with exact entities match (type and offset)
string: accept relations with correct entity strings
(default ALL)
```


The `pred_file` and `gold_file` use the following JSON Lines format (one object per line):
```json
{
"id": 0,
"text": "abcdefg",
"entity": [
{"ent_id": 0, "type": "ent_type_1", "offset": [0, 1], "text": "a"},
{"ent_id": 1, "type": "ent_type_2", "offset": [1, 2], "text": "b"},
{"ent_id": 2, "type": "ent_type_1", "offset": [2, 4], "text": "cd"},
{"ent_id": 3, "type": "ent_type_2", "offset": [4, 5], "text": "e"},
{"ent_id": 4, "type": "ent_type_3", "offset": [5, 7], "text": "fg"}
],
"relation": [{
"type": "rel_type_1",
"args": [0, 4]
},
{
"type": "rel_type_2",
"args": [1, 3]
}]
}
```
Arguments of relations are specified by their "ent_id".
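As a sketch of how this format is consumed, the helpers below (illustrative names, not part of `eval.py`) load a file and build the sets that an "exact" entity/relation comparison would use, assuming offsets are end-exclusive spans as in the example above:

```python
import json

def load_jsonl(path):
    """Read a pred/gold file: one JSON object per line."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

def exact_entities(doc):
    """'exact' criterion: an entity matches on (type, start, end)."""
    return {(e["type"], tuple(e["offset"])) for e in doc["entity"]}

def exact_relations(doc):
    """'exact' criterion: relation type plus both argument entities must match."""
    ents = {e["ent_id"]: (e["type"], tuple(e["offset"])) for e in doc["entity"]}
    return {(r["type"], ents[r["args"][0]], ents[r["args"][1]])
            for r in doc["relation"]}
```

Precision and recall then follow from intersecting these sets between each prediction/gold document pair.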


## eval_old.py

This is the evaluation script used in the original UniRE paper.
We keep it for backward compatibility and will remove it later.