|
| 1 | +import os |
| 2 | +import sys |
| 3 | +import argparse |
| 4 | +sys.path.append("../") |
| 5 | + |
| 6 | +import torch |
| 7 | +from volta.config import BertConfig |
| 8 | +from volta.encoders import BertForVLPreTraining |
| 9 | + |
| 10 | + |
# Command-line interface: path of the source LXMERT checkpoint, path for the
# converted VOLTA checkpoint, and a flag to print every name mapping.
parser = argparse.ArgumentParser()
for flag, opts in (
    ("--input_fn", dict(type=str, default="Epoch20_LXRT.pth")),
    ("--output_fn", dict(type=str, default="lxmert_checkpoint_19.bin")),
    ("--verbose", dict(action="store_true", default=False)),
):
    parser.add_argument(flag, **opts)
args = parser.parse_args()
| 17 | + |
# Load the original LXMERT checkpoint entirely onto the CPU
# (no GPU is required for a name-remapping conversion).
original_ckpt = torch.load(args.input_fn, map_location="cpu")

# Build the VOLTA model that the converted weights will be loaded into,
# using a config that mirrors the original LXMERT architecture.
# NOTE(review): the relative config path assumes the script is run from a
# sibling directory of `config/` — confirm the expected working directory.
config_file = "../config/original_lxmert.json"
config = BertConfig.from_json_file(config_file)
model = BertForVLPreTraining.from_pretrained("bert-base-uncased", config=config, default_gpu=True, from_hf=True)
# Target state dict; its entries are overwritten in-place by the mapping below.
trg_dict = model.state_dict()
| 26 | + |
# Map original parameters onto VOLTA ones.
# Each original checkpoint key `k` is rewritten through ordered string
# substitutions into the corresponding VOLTA parameter name `ln`;
# `volta2original` maps VOLTA name -> original key.  Substitution order
# matters: branch-specific renames must run before the generic ones below.
first_xlayer = config.tv_attn_sublayers[0]  # VOLTA index of the first cross-modal sublayer
volta2original = dict()
for k in original_ckpt.keys():
    # Strip the DataParallel "module." prefix, then rename the modules that
    # embed visual features and box locations.
    ln = k.replace('module.', '')
    ln = ln.replace("encoder.visn_fc", "v_embeddings")
    ln = ln.replace("visn_fc", "image_embeddings")
    ln = ln.replace("visn_layer_norm", "ImgLayerNorm")
    ln = ln.replace("box_fc", "image_location_embeddings")
    ln = ln.replace("box_layer_norm", "LocLayerNorm")

    # Flatten the attention submodule path ("attention.self" -> "attention_self").
    ln = ln.replace('attention.self', 'attention_self')
    ln = ln.replace('attention.output', 'attention_output')
    if '.layer.' in ln:
        # Text-only layer: original layer i becomes VOLTA sublayer 2*i for the
        # attention parts and 2*i+1 for intermediate/output (the boolean is
        # used as an int, 0 or 1).
        num = int(ln.split(".")[3])
        new = 2*num + ('.intermediate.' in ln or '.output.' in ln)
        # NOTE(review): replaces the first ".{num}." occurrence — relies on the
        # layer index being the first such dotted number in the key.
        ln = ln.replace(f".{num}.", f".{new}.")
    elif "r_layers" in ln:
        # Vision-only layer: same even/odd sublayer split as text layers, but
        # the VOLTA parameters carry the visual "v_" prefix.
        num = int(ln.split(".")[3])
        new = 2*num + ('.intermediate.' in ln or '.output.' in ln)
        ln = ln.replace(f"r_layers.{num}.", f"layer.{new}.")
        ln = ln.replace('.query.', '.v_query.')
        ln = ln.replace('.key.', '.v_key.')
        ln = ln.replace('.value.', '.v_value.')
        ln = ln.replace("dense", "v_dense")
        ln = ln.replace('.LayerNorm.', '.v_LayerNorm.')
    elif "x_layers" in ln:
        # Cross-modal layer: original x_layer i expands to three consecutive
        # VOLTA sublayers starting at 3*i + first_xlayer:
        #   +0 cross-attention, +1 self-attention, +2 intermediate/output.
        num = int(ln.split(".")[3])
        new = 3*num + first_xlayer
        if '.visual_attention.' in ln:
            # The single cross-attention module is registered twice: under the
            # "v_"-prefixed name (lnv, stored here) and under the textual name
            # (ln, stored by the final assignment at the bottom of the loop) —
            # both map back to the same original tensor `k`.
            ln = ln.replace(f"x_layers.{num}.visual_attention.att", f"layer.{new}.attention_self")
            lnv = ln.replace('.query.', '.v_query.')
            lnv = lnv.replace('.key.', '.v_key.')
            lnv = lnv.replace('.value.', '.v_value.')
            volta2original[lnv] = k
        elif '.visual_attention_output.' in ln:
            # Same double registration for the cross-attention output projection.
            ln = ln.replace(f"x_layers.{num}.visual_attention_output", f"layer.{new}.attention_output")
            lnv = ln.replace('.dense.', '.v_dense.')
            lnv = lnv.replace('.LayerNorm.', '.v_LayerNorm.')
            volta2original[lnv] = k
        elif '.lang_self_att.' in ln:
            # Text self-attention goes to the +1 sublayer (plain names).
            new += 1
            ln = ln.replace(f"x_layers.{num}.lang_self_att.self", f"layer.{new}.attention_self")
            ln = ln.replace(f"x_layers.{num}.lang_self_att.output", f"layer.{new}.attention_output")
        elif '.visn_self_att' in ln:
            # Vision self-attention goes to the same +1 sublayer as the text
            # one, distinguished by the "v_" prefix on its parameters.
            new += 1
            ln = ln.replace(f"x_layers.{num}.visn_self_att.self", f"layer.{new}.attention_self")
            ln = ln.replace(f"x_layers.{num}.visn_self_att.output", f"layer.{new}.attention_output")
            ln = ln.replace('.query.', '.v_query.')
            ln = ln.replace('.key.', '.v_key.')
            ln = ln.replace('.value.', '.v_value.')
            ln = ln.replace('.dense.', '.v_dense.')
            ln = ln.replace('.LayerNorm.', '.v_LayerNorm.')
        elif '.lang_inter.' in ln:
            # Feed-forward (intermediate/output) pieces go to the +2 sublayer.
            new += 2
            ln = ln.replace(f"x_layers.{num}.lang_inter.", f"layer.{new}.intermediate.")
        elif '.visn_inter.' in ln:
            new += 2
            ln = ln.replace(f"x_layers.{num}.visn_inter.", f"layer.{new}.intermediate.")
            ln = ln.replace('.dense.', '.v_dense.')
        elif '.lang_output.' in ln:
            new += 2
            ln = ln.replace(f"x_layers.{num}.lang_output.", f"layer.{new}.output.")
        elif '.visn_output.' in ln:
            new += 2
            ln = ln.replace(f"x_layers.{num}.visn_output.", f"layer.{new}.output.")
            ln = ln.replace('.LayerNorm.', '.v_LayerNorm.')
            ln = ln.replace('.dense.', '.v_dense.')

    # Pre-training head renames (next-sentence/pooler/QA/object-prediction).
    ln = ln.replace("seq_relationship", "bi_seq_relationship")
    ln = ln.replace("pooler", "t_pooler")
    ln = ln.replace("answer_head", "cls.qaPredictions")
    ln = ln.replace("obj_predict_head", "cls.imagePredictions")
    # The visual decoder heads are indexed numerically in VOLTA:
    # obj -> 3, attr -> 4, feat -> 5.
    ln = ln.replace("decoder_dict.obj", "decoder_dict.3")
    ln = ln.replace("decoder_dict.attr", "decoder_dict.4")
    ln = ln.replace("decoder_dict.feat", "decoder_dict.5")

    volta2original[ln] = k
| 105 | + |
# Apply the mapping: copy every original tensor into the VOLTA state dict,
# checking that source and target shapes agree before overwriting.
for trg, src in volta2original.items():
    if args.verbose:
        print(trg, '<-', src)
    # Raise explicitly instead of using `assert`, which would be silently
    # stripped under `python -O` and let a mis-mapped tensor through.
    if trg_dict[trg].shape != original_ckpt[src].shape:
        raise ValueError(
            f"Shape mismatch for {trg} <- {src}: "
            f"{tuple(trg_dict[trg].shape)} vs {tuple(original_ckpt[src].shape)}"
        )
    trg_dict[trg] = original_ckpt[src]
# load_state_dict also validates that the dict covers the model's parameters.
model.load_state_dict(trg_dict)

# Save checkpoint of the converted VOLTA model
torch.save(model.state_dict(), args.output_fn)
0 commit comments