Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Please switch to the `custom_mention_extractor` branch after running `git submodule update --recursive`.
71 changes: 59 additions & 12 deletions bootleg/end2end/bootleg_annotator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""BootlegAnnotator."""
"""BootlegAnnotator.

Author: shounak.kundu
"""

import logging
from operator import is_
import os
import tarfile
import urllib
Expand All @@ -12,12 +17,12 @@
from emmental.model import EmmentalModel
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from bootleg.dataset import extract_context, get_entity_string
from bootleg.end2end.annotator_utils import DownloadProgressBar
from bootleg.end2end.extract_mentions import MENTION_EXTRACTOR_OPTIONS
from bootleg.symbols.constants import PAD_ID
from bootleg.symbols.entity_symbols import EntitySymbols
from bootleg.symbols.entity_profile import EntityProfile
from bootleg.symbols.kg_symbols import KGSymbols
from bootleg.symbols.type_symbols import TypeSymbols
from bootleg.task_config import NED_TASK
Expand Down Expand Up @@ -49,7 +54,6 @@ def get_default_cache():
)
return Path(torch_cache_home) / "bootleg"


def create_config(model_path, data_path, model_name):
"""Create Bootleg config.

Expand Down Expand Up @@ -83,7 +87,6 @@ def create_config(model_path, data_path, model_name):
config_args = parse_boot_and_emm_args(config_args)
return config_args


def create_sources(model_path, data_path, model_name):
"""Download Bootleg data and save it in the log dir.

Expand Down Expand Up @@ -241,7 +244,12 @@ def __init__(
),
alias_cand_map_dir=self.config.data_config.alias_cand_map,
alias_idx_dir=self.config.data_config.alias_idx_map,
edit_mode=False
)
# self.entity_profile = EntityProfile.load_from_cache(\
# load_dir=self.config.data_config.entity_dir,\
# no_type=True,edit_mode=False,\
# verbose=True)
self.all_aliases_trie = self.entity_db.get_all_alias_vocabtrie()

add_entity_type = self.config.data_config.entity_type_data.use_entity_types
Expand All @@ -258,14 +266,14 @@ def __init__(
add_entity_kg = self.config.data_config.entity_kg_data.use_entity_kg
self.kg_symbols = None
# If we do not have self.entity_emb_file, then need to generate entity encoder input with metadata
if add_entity_kg and self.entity_emb_file is None:
logger.debug("Reading entity kg database")
self.kg_symbols = KGSymbols.load_from_cache(
os.path.join(
self.config.data_config.entity_dir,
self.config.data_config.entity_kg_data.kg_symbols_dir,
)
# if add_entity_kg and self.entity_emb_file is None:
logger.debug("Reading entity kg database")
self.kg_symbols = KGSymbols.load_from_cache(
os.path.join(
self.config.data_config.entity_dir,
self.config.data_config.entity_kg_data.kg_symbols_dir,
)
)
logger.debug("Reading word tokenizers")
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.data_config.word_embedding.bert_model,
Expand Down Expand Up @@ -311,13 +319,22 @@ def extract_mentions(self, text):

Returns: JSON object of sentence to be used in eval
"""
found_aliases, found_spans, found_char_spans = MENTION_EXTRACTOR_OPTIONS[
found_aliases, found_spans,\
found_char_spans , \
org_entity_list, \
per_entity_list, \
loc_list, \
type_list = MENTION_EXTRACTOR_OPTIONS[
self.extract_method
](text, self.all_aliases_trie, self.min_alias_len, self.max_alias_len)
return {
"sentence": text,
"aliases": found_aliases,
"char_spans": found_char_spans,
"org_entity_list" : org_entity_list,
"per_entity_list" : per_entity_list,
"loc_entity_list" : loc_list,
"type_entity_list": type_list,
"cands": [self.entity_db.get_qid_cands(al) for al in found_aliases],
# we don't know the true QID
"qids": ["Q-1" for i in range(len(found_aliases))],
Expand Down Expand Up @@ -418,6 +435,12 @@ def label_mentions(
batch_char_spans_arr = []
batch_example_aliases = []
batch_idx_unq = []

batch_ner_org_list=[]
batch_ner_per_list=[]
batch_ner_loc_list=[]
batch_ner_type_list=[]

for idx_unq in tqdm(
range(num_exs),
desc="Prepping data",
Expand All @@ -426,6 +449,10 @@ def label_mentions(
):
if do_extract_mentions:
sample = self.extract_mentions(text_list[idx_unq])
batch_ner_org_list.append(sample["org_entity_list"])
batch_ner_per_list.append(sample["per_entity_list"])
batch_ner_loc_list.append(sample["loc_entity_list"])
batch_ner_type_list.append(sample["type_entity_list"])
else:
sample = extracted_examples[idx_unq]
# Add the unk qids and gold values
Expand Down Expand Up @@ -523,6 +550,7 @@ def label_mentions(
batch_example_true_entities = torch.tensor(batch_example_true_entities)

final_pred_cands = [[] for _ in range(num_exs)]
final_pred_cands_types = [[] for _ in range(num_exs)]
final_all_cands = [[] for _ in range(num_exs)]
final_cand_probs = [[] for _ in range(num_exs)]
final_pred_probs = [[] for _ in range(num_exs)]
Expand Down Expand Up @@ -588,9 +616,23 @@ def label_mentions(
pred_prob = max_probs[ex_i].item()
pred_qid = entity_cands[pred_idx]
if pred_prob > self.threshold:
is_org=False
final_all_cands[idx_unq].append(entity_cands)
final_cand_probs[idx_unq].append(probs_ex)
final_pred_cands[idx_unq].append(pred_qid)
entity_relation_dict=self.kg_symbols.get_relations_tails_for_qid(pred_qid)
if 'instance of' in entity_relation_dict:
instance_of_list = entity_relation_dict['instance of']
if 'Q5' not in instance_of_list: ## Q5 means human
is_org=True
elif 'place of birth' in entity_relation_dict:
is_org=False
else:
is_org=True
if is_org:
final_pred_cands_types[idx_unq].append("ORG")
else:
final_pred_cands_types[idx_unq].append("PER")
final_pred_probs[idx_unq].append(pred_prob)
if self.return_embs:
final_entity_embs[idx_unq].append(
Expand All @@ -617,10 +659,15 @@ def label_mentions(
"qids": final_pred_cands,
"probs": final_pred_probs,
"titles": final_titles,
"qid_types": final_pred_cands_types,
"cands": final_all_cands,
"cand_probs": final_cand_probs,
"char_spans": final_char_spans,
"aliases": final_aliases,
"org_entity_list": batch_ner_org_list,
"per_entity_list": batch_ner_per_list,
"loc_entity_list": batch_ner_loc_list,
"type_entity_list": batch_ner_type_list
}
if self.return_embs:
res_dict["embs"] = final_entity_embs
Expand Down
40 changes: 30 additions & 10 deletions bootleg/utils/mention_extractor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import string
from collections import namedtuple
from typing import List, Tuple, Union

import torch
import nltk
import spacy
from spacy.cli.download import download as spacy_download

from bootleg.symbols.constants import LANG_CODE
from bootleg.utils.utils import get_lnrm


logger = logging.getLogger(__name__)

span_tuple = namedtuple("Span", ["text", "start_char_idx", "end_char_idx"])
Expand All @@ -25,10 +26,13 @@
except OSError:
nlp = None

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

try:
import flair
from flair.data import Sentence
from flair.models import SequenceTagger

flair.device = torch.device(DEVICE)
tagger_fast = SequenceTagger.load("ner-ontonotes-fast")
except ImportError:
tagger_fast = None
Expand All @@ -49,7 +53,7 @@
#"NORP",
"ORG",
"GPE",
"LOC",
#"LOC",
#"PRODUCT",
#"EVENT",
#"WORK_OF_ART",
Expand Down Expand Up @@ -288,8 +292,14 @@ def my_mention_extractor(
"""

sentence = Sentence(text)
tagger_fast.predict(sentence, mini_batch_size=16)
tagger_fast.predict(sentence)
entities = []
org_entities=[]
per_entities=[]
loc_entities=[]
type_entities=[]
is_org=False
is_per=False
for i in range(len(sentence.to_dict(tag_type="ner")["entities"])):
str_main = None
start_pos = -1
Expand All @@ -301,6 +311,7 @@ def my_mention_extractor(
str_main = str(sentence.to_dict(tag_type="ner")["entities"][i]["text"])
start_pos = sentence.to_dict(tag_type="ner")["entities"][i]["start_pos"]
end_pos = sentence.to_dict(tag_type="ner")["entities"][i]["end_pos"]
is_org=True

elif (
str(sentence.to_dict(tag_type="ner")["entities"][i]["labels"][0]).split()[0]
Expand All @@ -309,14 +320,21 @@ def my_mention_extractor(
str_main = str(sentence.to_dict(tag_type="ner")["entities"][i]["text"])
start_pos = sentence.to_dict(tag_type="ner")["entities"][i]["start_pos"]
end_pos = sentence.to_dict(tag_type="ner")["entities"][i]["end_pos"]
is_per=True

elif (
str(sentence.to_dict(tag_type="ner")["entities"][i]["labels"][0]).split()[0]
in "GPE"
):
str_main = str(sentence.to_dict(tag_type="ner")["entities"][i]["text"])
start_pos = sentence.to_dict(tag_type="ner")["entities"][i]["start_pos"]
end_pos = sentence.to_dict(tag_type="ner")["entities"][i]["end_pos"]
in "GPE"):
loc_value = str(sentence.to_dict(tag_type="ner")["entities"][i]["text"])
loc_entities.append(loc_value)
if is_org:
org_text_entity = str_main
org_entities.append(org_text_entity)
type_entities.append(["ORG",start_pos,end_pos])
if is_per:
per_text_entity = str_main
per_entities.append(per_text_entity)
type_entities.append(["PER",start_pos,end_pos])
if str_main is not None and (start_pos != -1 and end_pos != -1):
final_gram = None
if str_main in all_aliases:
Expand All @@ -331,8 +349,10 @@ def my_mention_extractor(
final_gram = joined_gram_merged_noplural
if final_gram is not None:
entities.append([final_gram, start_pos, end_pos])
is_org=False
is_per=False

used_aliases = [item[0] for item in entities]
chars = [[item[1], item[2]] for item in entities]
spans = [[len(text[: sp[0]].split()), len(text[: sp[1]].split())] for sp in chars]
return used_aliases, spans, chars
return used_aliases, spans, chars , org_entities, per_entities , loc_entities , type_entities