"""This script just trains models from scratch, to later be pruned
@Author: Nathan Greffe
'"""
import argparse
import json
import torch.optim.lr_scheduler as lr_scheduler
import torch
import torch.nn as nn
import os
from models.wideresnet import *
from utils.training_fcts import *
from utils.data_fcts import *
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--workers', default=0, type=int, metavar='N', help='number of data loading workers')
parser.add_argument('--save_file', default='saveto', type=str, help='save file for checkpoints')
parser.add_argument('--data_loc', default='~/Documents/CIFAR-10', help='folder containing the CIFAR-10 dataset')
parser.add_argument('--full_train', dest='full_train', action='store_true',
                    help='trains with the full dataset (NetAdapt makes more sense if the base model is trained with'
                         ' the training set only)')
# Learning specific arguments
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--learning_rate', default=.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--lr_type', default='multistep', type=str, help='learning rate strategy (default: multistep)',
                    choices=['cosine', 'multistep'])
parser.add_argument('--epochs', default=200, type=int, metavar='epochs', help='no. epochs')
parser.add_argument('--epoch_step', default='[60,120,160]', type=str, help='json list with epochs to drop lr on')
parser.add_argument('--lr_decay_ratio', default=0.2, type=float, help='learning rate decay factor')
parser.add_argument('--holdout_prop', default=0.1, type=float, help='fraction of training set used for holdout')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float, metavar='W', help='weight decay')
parser.add_argument('--net', choices=['res'], default='res')
# Net specific
parser.add_argument('--depth', '-d', default=40, type=int, metavar='D', help='depth of wideresnet/densenet')
parser.add_argument('--width', '-w', default=2.0, type=float, metavar='W', help='width of wideresnet')
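# Example invocation (a sketch; the dataset path and hyperparameter values below are illustrative,
# not prescribed by this repository):
#   python train.py --data_loc ~/Documents/CIFAR-10 --save_file wrn_40_2 \
#       --depth 40 --width 2.0 --epochs 200 --lr_type multistep --epoch_step "[60,120,160]"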
if not os.path.exists('checkpoints'):
    os.makedirs('checkpoints')
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
if args.net == 'res':
    model = WideResNet(args.depth, args.width, device)
else:
    raise ValueError('pick a valid net')
error_history = []
model.to(device)
# base datasets
if args.full_train:
    train_loader, val_loader = get_full_train_val(args.data_loc, args.workers, args.batch_size)
else:
    train_loader, _ = get_train_holdout(args.data_loc, args.workers, args.batch_size, args.holdout_prop)
    _, val_loader = get_full_train_val(args.data_loc, args.workers, args.batch_size)
if __name__ == '__main__':
    filename = os.path.join('checkpoints', f'{args.save_file}.t7')

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([v for v in model.parameters() if v.requires_grad],
                                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.lr_type == "multistep":
        epoch_step = json.loads(args.epoch_step)
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=epoch_step, gamma=args.lr_decay_ratio)
    elif args.lr_type == "cosine":
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
    else:
        raise ValueError('pick a valid learning rate type')

    for epoch in range(1, args.epochs + 1):
        print(f"Epoch {epoch} -- lr is {optimizer.param_groups[0]['lr']}:")
        train(model, optimizer, train_loader, criterion, device)

        # evaluate on validation set
        error_history.append(validate(model, val_loader, criterion, device))

        # step the scheduler after the epoch's optimizer updates (required ordering since PyTorch 1.1)
        scheduler.step()

        # only the final model is checkpointed
        if epoch == args.epochs:
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'error_history': error_history,
            }, filename)
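# Reloading the saved checkpoint later (a minimal sketch; assumes the model is rebuilt with the
# same --depth/--width arguments used for training):
#   ckpt = torch.load(os.path.join('checkpoints', f'{args.save_file}.t7'), map_location=device)
#   model.load_state_dict(ckpt['state_dict'])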