diff --git a/README.txt b/README.txt new file mode 100644 index 00000000..6facd9cb --- /dev/null +++ b/README.txt @@ -0,0 +1 @@ +Please switch to the custom_mention_extractor branch after running `git submodule update --recursive`. \ No newline at end of file diff --git a/bootleg/end2end/bootleg_annotator.py b/bootleg/end2end/bootleg_annotator.py index 5ea097bf..4b979182 100644 --- a/bootleg/end2end/bootleg_annotator.py +++ b/bootleg/end2end/bootleg_annotator.py @@ -1,5 +1,10 @@ -"""BootlegAnnotator.""" +"""BootlegAnnotator. + +author shounak.kundu +""" + import logging +from operator import is_ import os import tarfile import urllib @@ -12,12 +17,12 @@ from emmental.model import EmmentalModel from tqdm.auto import tqdm from transformers import AutoTokenizer - from bootleg.dataset import extract_context, get_entity_string from bootleg.end2end.annotator_utils import DownloadProgressBar from bootleg.end2end.extract_mentions import MENTION_EXTRACTOR_OPTIONS from bootleg.symbols.constants import PAD_ID from bootleg.symbols.entity_symbols import EntitySymbols +from bootleg.symbols.entity_profile import EntityProfile from bootleg.symbols.kg_symbols import KGSymbols from bootleg.symbols.type_symbols import TypeSymbols from bootleg.task_config import NED_TASK @@ -49,7 +54,6 @@ def get_default_cache(): ) return Path(torch_cache_home) / "bootleg" - def create_config(model_path, data_path, model_name): """Create Bootleg config. @@ -83,7 +87,6 @@ def create_config(model_path, data_path, model_name): config_args = parse_boot_and_emm_args(config_args) return config_args - def create_sources(model_path, data_path, model_name): """Download Bootleg data and saves in log dir. 
@@ -241,7 +244,12 @@ def __init__( ), alias_cand_map_dir=self.config.data_config.alias_cand_map, alias_idx_dir=self.config.data_config.alias_idx_map, + edit_mode=False ) + # self.entity_profile = EntityProfile.load_from_cache(\ + # load_dir=self.config.data_config.entity_dir,\ + # no_type=True,edit_mode=False,\ + # verbose=True) self.all_aliases_trie = self.entity_db.get_all_alias_vocabtrie() add_entity_type = self.config.data_config.entity_type_data.use_entity_types @@ -258,14 +266,14 @@ def __init__( add_entity_kg = self.config.data_config.entity_kg_data.use_entity_kg self.kg_symbols = None # If we do not have self.entity_emb_file, then need to generate entity encoder input with metadata - if add_entity_kg and self.entity_emb_file is None: - logger.debug("Reading entity kg database") - self.kg_symbols = KGSymbols.load_from_cache( - os.path.join( - self.config.data_config.entity_dir, - self.config.data_config.entity_kg_data.kg_symbols_dir, - ) + # if add_entity_kg and self.entity_emb_file is None: + logger.debug("Reading entity kg database") + self.kg_symbols = KGSymbols.load_from_cache( + os.path.join( + self.config.data_config.entity_dir, + self.config.data_config.entity_kg_data.kg_symbols_dir, ) + ) logger.debug("Reading word tokenizers") self.tokenizer = AutoTokenizer.from_pretrained( self.config.data_config.word_embedding.bert_model, @@ -311,13 +319,22 @@ def extract_mentions(self, text): Returns: JSON object of sentence to be used in eval """ - found_aliases, found_spans, found_char_spans = MENTION_EXTRACTOR_OPTIONS[ + found_aliases, found_spans,\ + found_char_spans , \ + org_entity_list, \ + per_entity_list, \ + loc_list, \ + type_list = MENTION_EXTRACTOR_OPTIONS[ self.extract_method ](text, self.all_aliases_trie, self.min_alias_len, self.max_alias_len) return { "sentence": text, "aliases": found_aliases, "char_spans": found_char_spans, + "org_entity_list" : org_entity_list, + "per_entity_list" : per_entity_list, + "loc_entity_list" : loc_list, + 
"type_entity_list": type_list, "cands": [self.entity_db.get_qid_cands(al) for al in found_aliases], # we don't know the true QID "qids": ["Q-1" for i in range(len(found_aliases))], @@ -418,6 +435,12 @@ def label_mentions( batch_char_spans_arr = [] batch_example_aliases = [] batch_idx_unq = [] + + batch_ner_org_list=[] + batch_ner_per_list=[] + batch_ner_loc_list=[] + batch_ner_type_list=[] + for idx_unq in tqdm( range(num_exs), desc="Prepping data", @@ -426,6 +449,10 @@ def label_mentions( ): if do_extract_mentions: sample = self.extract_mentions(text_list[idx_unq]) + batch_ner_org_list.append(sample["org_entity_list"]) + batch_ner_per_list.append(sample["per_entity_list"]) + batch_ner_loc_list.append(sample["loc_entity_list"]) + batch_ner_type_list.append(sample["type_entity_list"]) else: sample = extracted_examples[idx_unq] # Add the unk qids and gold values @@ -523,6 +550,7 @@ def label_mentions( batch_example_true_entities = torch.tensor(batch_example_true_entities) final_pred_cands = [[] for _ in range(num_exs)] + final_pred_cands_types = [[] for _ in range(num_exs)] final_all_cands = [[] for _ in range(num_exs)] final_cand_probs = [[] for _ in range(num_exs)] final_pred_probs = [[] for _ in range(num_exs)] @@ -588,9 +616,23 @@ def label_mentions( pred_prob = max_probs[ex_i].item() pred_qid = entity_cands[pred_idx] if pred_prob > self.threshold: + is_org=False final_all_cands[idx_unq].append(entity_cands) final_cand_probs[idx_unq].append(probs_ex) final_pred_cands[idx_unq].append(pred_qid) + entity_relation_dict=self.kg_symbols.get_relations_tails_for_qid(pred_qid) + if 'instance of' in entity_relation_dict: + instance_of_list = entity_relation_dict['instance of'] + if 'Q5' not in instance_of_list: ## Q5 means human + is_org=True + elif 'place of birth' in entity_relation_dict: + is_org=False + else: + is_org=True + if is_org: + final_pred_cands_types[idx_unq].append("ORG") + else: + final_pred_cands_types[idx_unq].append("PER") 
final_pred_probs[idx_unq].append(pred_prob) if self.return_embs: final_entity_embs[idx_unq].append( @@ -617,10 +659,15 @@ def label_mentions( "qids": final_pred_cands, "probs": final_pred_probs, "titles": final_titles, + "qid_types": final_pred_cands_types, "cands": final_all_cands, "cand_probs": final_cand_probs, "char_spans": final_char_spans, "aliases": final_aliases, + "org_entity_list": batch_ner_org_list, + "per_entity_list": batch_ner_per_list, + "loc_entity_list": batch_ner_loc_list, + "type_entity_list": batch_ner_type_list } if self.return_embs: res_dict["embs"] = final_entity_embs diff --git a/bootleg/utils/mention_extractor_utils.py b/bootleg/utils/mention_extractor_utils.py index 762a8980..a5454465 100644 --- a/bootleg/utils/mention_extractor_utils.py +++ b/bootleg/utils/mention_extractor_utils.py @@ -2,7 +2,7 @@ import string from collections import namedtuple from typing import List, Tuple, Union - +import torch import nltk import spacy from spacy.cli.download import download as spacy_download @@ -10,6 +10,7 @@ from bootleg.symbols.constants import LANG_CODE from bootleg.utils.utils import get_lnrm + logger = logging.getLogger(__name__) span_tuple = namedtuple("Span", ["text", "start_char_idx", "end_char_idx"]) @@ -25,10 +26,13 @@ except OSError: nlp = None +DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' + try: + import flair from flair.data import Sentence from flair.models import SequenceTagger - + flair.device = torch.device(DEVICE) tagger_fast = SequenceTagger.load("ner-ontonotes-fast") except ImportError: tagger_fast = None @@ -49,7 +53,7 @@ #"NORP", "ORG", "GPE", - "LOC", + #"LOC", #"PRODUCT", #"EVENT", #"WORK_OF_ART", @@ -288,8 +292,14 @@ def my_mention_extractor( """ sentence = Sentence(text) - tagger_fast.predict(sentence, mini_batch_size=16) + tagger_fast.predict(sentence) entities = [] + org_entities=[] + per_entities=[] + loc_entities=[] + type_entities=[] + is_org=False + is_per=False for i in 
range(len(sentence.to_dict(tag_type="ner")["entities"])): str_main = None start_pos = -1 @@ -301,6 +311,7 @@ def my_mention_extractor( str_main = str(sentence.to_dict(tag_type="ner")["entities"][i]["text"]) start_pos = sentence.to_dict(tag_type="ner")["entities"][i]["start_pos"] end_pos = sentence.to_dict(tag_type="ner")["entities"][i]["end_pos"] + is_org=True elif ( str(sentence.to_dict(tag_type="ner")["entities"][i]["labels"][0]).split()[0] @@ -309,14 +320,21 @@ def my_mention_extractor( str_main = str(sentence.to_dict(tag_type="ner")["entities"][i]["text"]) start_pos = sentence.to_dict(tag_type="ner")["entities"][i]["start_pos"] end_pos = sentence.to_dict(tag_type="ner")["entities"][i]["end_pos"] + is_per=True elif ( str(sentence.to_dict(tag_type="ner")["entities"][i]["labels"][0]).split()[0] - in "GPE" - ): - str_main = str(sentence.to_dict(tag_type="ner")["entities"][i]["text"]) - start_pos = sentence.to_dict(tag_type="ner")["entities"][i]["start_pos"] - end_pos = sentence.to_dict(tag_type="ner")["entities"][i]["end_pos"] + in "GPE"): + loc_value = str(sentence.to_dict(tag_type="ner")["entities"][i]["text"]) + loc_entities.append(loc_value) + if is_org: + org_text_entity = str_main + org_entities.append(org_text_entity) + type_entities.append(["ORG",start_pos,end_pos]) + if is_per: + per_text_entity = str_main + per_entities.append(per_text_entity) + type_entities.append(["PER",start_pos,end_pos]) if str_main is not None and (start_pos != -1 and end_pos != -1): final_gram = None if str_main in all_aliases: @@ -331,8 +349,10 @@ def my_mention_extractor( final_gram = joined_gram_merged_noplural if final_gram is not None: entities.append([final_gram, start_pos, end_pos]) + is_org=False + is_per=False used_aliases = [item[0] for item in entities] chars = [[item[1], item[2]] for item in entities] spans = [[len(text[: sp[0]].split()), len(text[: sp[1]].split())] for sp in chars] - return used_aliases, spans, chars + return used_aliases, spans, chars , org_entities, 
per_entities , loc_entities , type_entities