Commit 9e52021

Add original loading of LXMERT (Resolve #6)
1 parent 2b20a26 · commit 9e52021

4 files changed: 174 additions & 2 deletions


MODELS.md

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,11 @@ For the latter, we distribute the weights that lead to higher average downstream
 | [VisualBERT (CTRL)](https://sid.erda.dk/share_redirect/GCBlzUuoJl) | 69.03 | 70.02 | 72.70 | 61.48 | 75.20 |
 | [UNITER (CTRL)](https://sid.erda.dk/share_redirect/FeYIWpMSFg) | 68.67 | 71.45 | 73.73 | 60.54 | 76.40 |
 
+### Conversions of Original Models into VOLTA
+| Model | Source |
+|-------------------|--------|
+| [LXMERT (Original)](https://sid.erda.dk/share_redirect/cFGANaAtmN) | [airsplay/lxmert](https://nlp.cs.unc.edu/data/github_pretrain/lxmert20/Epoch20_LXRT.pth) |
+
 
 ## Models Definition

config/original_lxmert.json

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "pooler_size": 768,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "bert_model": "bert-base-uncased",
  "do_lower_case": true,
  "num_locs": 4,
  "image_embeddings": "lxmert",
  "v_attention_probs_dropout_prob": 0.1,
  "v_hidden_act": "gelu",
  "v_hidden_dropout_prob": 0.1,
  "v_feature_size": 2048,
  "visual_target_weights": {"3": 6.667, "4": 6.667, "5": 6.667},
  "qa_task_weight": 1,
  "qa_num_answers": 9500,
  "v_hidden_size": 768,
  "v_initializer_range": 0.02,
  "v_num_attention_heads": 12,
  "v_intermediate_size": 3072,
  "fusion_method": "text",
  "clf_hidden_size": 1536,
  "tt_attn_sublayers": [0,2,4,6,8,10,12,14,16,19,22,25,28,31],
  "tv_attn_sublayers": [18,21,24,27,30],
  "vt_attn_sublayers": [18,21,24,27,30],
  "vv_attn_sublayers": [0,2,4,6,8,19,22,25,28,31],
  "t_ff_sublayers": [1,3,5,7,9,11,13,15,17,20,23,26,29,32],
  "v_ff_sublayers": [1,3,5,7,9,20,23,26,29,32],
  "shared_sublayers": [18,21,24,27,30],
  "single_ln_sublayers": [],
  "sublayer2attn_hidden_size": {},
  "sublayer2num_attention_heads": {},
  "sublayer2intermediate_size": {},
  "sublayer2v_attn_hidden_size": {},
  "sublayer2v_num_attention_heads": {},
  "sublayer2v_intermediate_size": {},
  "bert_layer2attn_sublayer": {
    "0": 0, "1": 2, "2": 4, "3": 6, "4": 8, "5": 10,
    "6": 12, "7": 14, "8": 16, "9": 19, "10": 22, "11": 25
  },
  "bert_layer2ff_sublayer": {
    "0": 1, "1": 3, "2": 5, "3": 7, "4": 9, "5": 11,
    "6": 13, "7": 15, "8": 17, "9": 20, "10": 23, "11": 26
  }
}
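
For orientation: these sublayer lists lay LXMERT's architecture out as one flat VOLTA sequence. The nine language-only layers map to sublayers 0–17 and the five vision-only r_layers to sublayers 0–9 (attention at even indices, feed-forward at odd ones, the two streams running in parallel), while each of the five cross-modal x_layers expands into three consecutive sublayers from index 18 onwards: cross-attention (18, 21, 24, 27, 30, also listed as shared_sublayers, matching LXMERT's single cross-attention module serving both directions), self-attention (19, 22, 25, 28, 31) and feed-forward (20, 23, 26, 29, 32). This is the same arithmetic applied by conversions/convert_lxmert.py below. A minimal sketch for inspecting the config, run from the repository root and assuming (as the conversion script does) that VOLTA's BertConfig exposes the JSON keys as attributes:

from volta.config import BertConfig

# Load the new config and print the sublayer pattern it encodes
# (attribute names mirror the JSON keys above).
config = BertConfig.from_json_file("config/original_lxmert.json")
print(config.tt_attn_sublayers)  # text self-attention: 9 language-only layers + 5 cross-block self-attentions
print(config.vv_attn_sublayers)  # vision self-attention: 5 r_layers + 5 cross-block self-attentions
print(config.tv_attn_sublayers)  # cross-attention sublayers: [18, 21, 24, 27, 30]
print(config.shared_sublayers)   # same indices: cross-attention parameters shared across the two streams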

conversions/convert_lxmert.py

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
import os
import sys
import argparse
sys.path.append("../")

import torch
from volta.config import BertConfig
from volta.encoders import BertForVLPreTraining


# Inputs
parser = argparse.ArgumentParser()
parser.add_argument("--input_fn", type=str, default="Epoch20_LXRT.pth")
parser.add_argument("--output_fn", type=str, default="lxmert_checkpoint_19.bin")
parser.add_argument("--verbose", action="store_true", default=False)
args = parser.parse_args()

# Load original checkpoint
original_ckpt = torch.load(args.input_fn, map_location="cpu")

# Create corresponding VOLTA model
config_file = "../config/original_lxmert.json"
config = BertConfig.from_json_file(config_file)
model = BertForVLPreTraining.from_pretrained("bert-base-uncased", config=config, default_gpu=True, from_hf=True)
trg_dict = model.state_dict()

# Map original parameters onto VOLTA ones
first_xlayer = config.tv_attn_sublayers[0]
volta2original = dict()
for k in original_ckpt.keys():
    ln = k.replace('module.', '')
    ln = ln.replace("encoder.visn_fc", "v_embeddings")
    ln = ln.replace("visn_fc", "image_embeddings")
    ln = ln.replace("visn_layer_norm", "ImgLayerNorm")
    ln = ln.replace("box_fc", "image_location_embeddings")
    ln = ln.replace("box_layer_norm", "LocLayerNorm")

    ln = ln.replace('attention.self', 'attention_self')
    ln = ln.replace('attention.output', 'attention_output')
    if '.layer.' in ln:
        num = int(ln.split(".")[3])
        new = 2*num + ('.intermediate.' in ln or '.output.' in ln)
        ln = ln.replace(f".{num}.", f".{new}.")
    elif "r_layers" in ln:
        num = int(ln.split(".")[3])
        new = 2*num + ('.intermediate.' in ln or '.output.' in ln)
        ln = ln.replace(f"r_layers.{num}.", f"layer.{new}.")
        ln = ln.replace('.query.', '.v_query.')
        ln = ln.replace('.key.', '.v_key.')
        ln = ln.replace('.value.', '.v_value.')
        ln = ln.replace("dense", "v_dense")
        ln = ln.replace('.LayerNorm.', '.v_LayerNorm.')
    elif "x_layers" in ln:
        num = int(ln.split(".")[3])
        new = 3*num + first_xlayer
        if '.visual_attention.' in ln:
            ln = ln.replace(f"x_layers.{num}.visual_attention.att", f"layer.{new}.attention_self")
            lnv = ln.replace('.query.', '.v_query.')
            lnv = lnv.replace('.key.', '.v_key.')
            lnv = lnv.replace('.value.', '.v_value.')
            volta2original[lnv] = k
        elif '.visual_attention_output.' in ln:
            ln = ln.replace(f"x_layers.{num}.visual_attention_output", f"layer.{new}.attention_output")
            lnv = ln.replace('.dense.', '.v_dense.')
            lnv = lnv.replace('.LayerNorm.', '.v_LayerNorm.')
            volta2original[lnv] = k
        elif '.lang_self_att.' in ln:
            new += 1
            ln = ln.replace(f"x_layers.{num}.lang_self_att.self", f"layer.{new}.attention_self")
            ln = ln.replace(f"x_layers.{num}.lang_self_att.output", f"layer.{new}.attention_output")
        elif '.visn_self_att' in ln:
            new += 1
            ln = ln.replace(f"x_layers.{num}.visn_self_att.self", f"layer.{new}.attention_self")
            ln = ln.replace(f"x_layers.{num}.visn_self_att.output", f"layer.{new}.attention_output")
            ln = ln.replace('.query.', '.v_query.')
            ln = ln.replace('.key.', '.v_key.')
            ln = ln.replace('.value.', '.v_value.')
            ln = ln.replace('.dense.', '.v_dense.')
            ln = ln.replace('.LayerNorm.', '.v_LayerNorm.')
        elif '.lang_inter.' in ln:
            new += 2
            ln = ln.replace(f"x_layers.{num}.lang_inter.", f"layer.{new}.intermediate.")
        elif '.visn_inter.' in ln:
            new += 2
            ln = ln.replace(f"x_layers.{num}.visn_inter.", f"layer.{new}.intermediate.")
            ln = ln.replace('.dense.', '.v_dense.')
        elif '.lang_output.' in ln:
            new += 2
            ln = ln.replace(f"x_layers.{num}.lang_output.", f"layer.{new}.output.")
        elif '.visn_output.' in ln:
            new += 2
            ln = ln.replace(f"x_layers.{num}.visn_output.", f"layer.{new}.output.")
            ln = ln.replace('.LayerNorm.', '.v_LayerNorm.')
            ln = ln.replace('.dense.', '.v_dense.')

    ln = ln.replace("seq_relationship", "bi_seq_relationship")
    ln = ln.replace("pooler", "t_pooler")
    ln = ln.replace("answer_head", "cls.qaPredictions")
    ln = ln.replace("obj_predict_head", "cls.imagePredictions")
    ln = ln.replace("decoder_dict.obj", "decoder_dict.3")
    ln = ln.replace("decoder_dict.attr", "decoder_dict.4")
    ln = ln.replace("decoder_dict.feat", "decoder_dict.5")

    volta2original[ln] = k

# Apply mapping
for trg, src in volta2original.items():
    if args.verbose:
        print(trg, '<-', src)
    assert trg_dict[trg].shape == original_ckpt[src].shape
    trg_dict[trg] = original_ckpt[src]
model.load_state_dict(trg_dict)

# Save checkpoint of VOLTA model
torch.save(model.state_dict(), args.output_fn)
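
The script is meant to be run from the conversions/ directory, since it appends "../" to sys.path and reads ../config/original_lxmert.json; with the defaults above the invocation is simply python convert_lxmert.py --input_fn Epoch20_LXRT.pth --output_fn lxmert_checkpoint_19.bin --verbose. As a sanity check on the result, one might reload the converted checkpoint into a freshly built VOLTA model; a minimal sketch, assuming the default output filename and a run from the repository root:

import torch
from volta.config import BertConfig
from volta.encoders import BertForVLPreTraining

# Rebuild the same VOLTA model the conversion script instantiates.
config = BertConfig.from_json_file("config/original_lxmert.json")
model = BertForVLPreTraining.from_pretrained("bert-base-uncased", config=config, default_gpu=True, from_hf=True)

# Load the converted weights; strict loading (the default) raises if any
# parameter name or shape fails to match the VOLTA model.
state_dict = torch.load("conversions/lxmert_checkpoint_19.bin", map_location="cpu")
model.load_state_dict(state_dict)
print("Converted checkpoint loads cleanly into BertForVLPreTraining.")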

volta/encoders.py

Lines changed: 2 additions & 2 deletions
@@ -738,14 +738,14 @@ def forward(self, hidden_states):
 
 
 class LxmertAnswerHead(nn.Module):
-    def __init__(self, config, num_answers):
+    def __init__(self, config):
         super().__init__()
         hid_dim = config.v_hidden_size
         self.logit_fc = nn.Sequential(
             nn.Linear(hid_dim, hid_dim * 2),
             GeLU(),
             BertLayerNorm(hid_dim * 2, eps=1e-12),
-            nn.Linear(hid_dim * 2, num_answers)
+            nn.Linear(hid_dim * 2, config.qa_num_answers)
         )
 
     def forward(self, hidden_states):