-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathempty_model_create.py
More file actions
52 lines (39 loc) · 1.53 KB
/
empty_model_create.py
File metadata and controls
52 lines (39 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from argparse import ArgumentParser
from transformers import AutoTokenizer
from common.model.const import *
from common.sys.const import EVALUATE_WEIGHT_DIR
from learner import *
# Module-level fraction settings. NOTE(review): the names suggest CPU/GPU usage
# fractions, but neither constant is referenced anywhere else in this file;
# presumably kept for parity with sibling scripts — confirm before removing.
CPU_FRACTION, GPU_FRACTION = 1.0, 0.5
def read_arguments(argv=None):
    """Parse the command-line options used to build an empty EPT model.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour; passing a list lets other code or tests drive
            the parser without touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``encoder``, ``decoder_hidden``,
        ``decoder_intermediate``, ``decoder_layer``, ``decoder_head`` and
        ``simple`` (always ``False``).
    """
    parser = ArgumentParser()

    env = parser.add_argument_group('Dataset & Evaluation')
    # This group defines no options here; it only pins ``simple=False`` on the
    # resulting namespace (presumably expected by shared code — confirm).
    env.set_defaults(simple=False)

    model = parser.add_argument_group('Model')
    model.add_argument('--encoder', '-enc', type=str, default=DEF_ENCODER)
    # NOTE(review): a default of 0 for hidden/intermediate/head presumably
    # means "derive from the encoder" inside EPT — not visible here, confirm.
    model.add_argument('--decoder-hidden', '-decH', type=int, default=0)
    model.add_argument('--decoder-intermediate', '-decI', type=int, default=0)
    model.add_argument('--decoder-layer', '-decL', type=int, default=6)
    model.add_argument('--decoder-head', '-decA', type=int, default=0)

    return parser.parse_args(argv)
def build_model_config(args):
    """Translate parsed CLI options into keyword arguments for ``EPT``.

    The result has two sections: the encoder section maps ``MDL_ENCODER`` to
    the encoder name, and the decoder section carries the four decoder size
    options exactly as given on the command line.
    """
    encoder_section = {MDL_ENCODER: args.encoder}
    decoder_section = {
        MDL_D_HIDDEN: args.decoder_hidden,
        MDL_D_INTER: args.decoder_intermediate,
        MDL_D_LAYER: args.decoder_layer,
        MDL_D_HEAD: args.decoder_head,
    }
    return {MDL_ENCODER: encoder_section, MDL_DECODER: decoder_section}
def main():
    """Construct an EPT model from CLI options and save it with its tokenizer.

    The model is written via ``EPT.save`` into ``EVALUATE_WEIGHT_DIR`` (created
    if missing). The tokenizer loaded for ``--encoder`` is saved with
    ``torch.save`` to ``EVALUATE_WEIGHT_DIR / 'tokenizer.pt'``.
    """
    # Imported explicitly instead of relying on it leaking through
    # ``from learner import *``.
    import torch

    args = read_arguments()

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    EVALUATE_WEIGHT_DIR.mkdir(parents=True, exist_ok=True)

    model = EPT(**build_model_config(args))
    model.save(str(EVALUATE_WEIGHT_DIR.absolute()))

    tokenizer = AutoTokenizer.from_pretrained(args.encoder)
    # NOTE(review): torch.save pickles the whole tokenizer object, and loading
    # it back with torch.load unpickles it, so only load files you trust.
    # tokenizer.save_pretrained() would be a portable alternative, but
    # consumers presumably expect this exact file — confirm before changing.
    with (EVALUATE_WEIGHT_DIR / 'tokenizer.pt').open('wb') as fp:
        torch.save(tokenizer, fp)


if __name__ == '__main__':
    main()