This repository was archived by the owner on Jan 23, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
67 lines (60 loc) · 2.59 KB
/
train.py
File metadata and controls
67 lines (60 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import p2p
from p2p.config import args
from p2p.utils import createOptim
# Loss names that have an implementation selected below.
losses = ['MSE', 'VGGLoss']

if __name__ == '__main__':
    # ---- data -----------------------------------------------------------
    root_dir = args['root_dir']
    log_dir = '{}/{}'.format(root_dir, args['log_dir'])
    p2p_df = pd.read_csv('{}/{}'.format(root_dir, args['meta_data']))
    # Keep only clips with exactly 10 frames; split on the 'set' column.
    p2p_df = p2p_df[p2p_df['n_frames'] == 10]
    train = p2p_df[p2p_df['set'] == 'train']
    test = p2p_df[p2p_df['set'] == 'test']
    # Per-frame preprocessing: to tensor, then channel-wise normalization
    # with the mean/std supplied by the config.
    frame_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=args['norm_mean'],
                             std=args['norm_std'])
    ])
    data_path = '{}/{}'.format(root_dir, args['data_dir'])
    # BUGFIX: the original passed num_workers=args['batch_size'], tying the
    # DataLoader worker count to the batch size (copy-paste error). Honor an
    # explicit 'num_workers' config entry, falling back to the old behavior
    # so existing configs keep working unchanged.
    num_workers = args.get('num_workers', args['batch_size'])
    trainset = p2p.P2PDataset(
        df=train, transform=frame_transform, data_path=data_path)
    trainloader = DataLoader(trainset, batch_size=args['batch_size'],
                             shuffle=True, num_workers=num_workers)
    testset = p2p.P2PDataset(
        df=test, transform=frame_transform, data_path=data_path)
    # NOTE(review): shuffling the test set is unusual for evaluation —
    # kept as-is, but confirm this is intentional.
    testloader = DataLoader(testset, batch_size=args['batch_size'],
                            shuffle=True, num_workers=num_workers)
    # ---- models ---------------------------------------------------------
    # Generator takes 6 input channels (presumably two stacked RGB frames —
    # TODO confirm against P2PDataset) and emits a 3-channel image.
    netG = p2p.GlobalGenerator(input_nc=6, output_nc=3)
    # Optional pretrained face-identity network used as an auxiliary model.
    if args['face_id']:
        faceid = p2p.VoxFaceID(
            pretrain_path='{}/p2p/pretrained/pretrained_VoxFaceID.pth.tar'.format(root_dir))
    else:
        faceid = None
    device_ids = list(args['gpu'])
    if len(args['gpu']) > 1:
        print('Data Parallel on {}'.format(args['gpu']))
        netG = nn.DataParallel(netG, device_ids=device_ids).to(device_ids[0])
        if args['face_id']:
            faceid = nn.DataParallel(
                faceid, device_ids=device_ids).to(device_ids[0])
    # NOTE(review): in the single-GPU case the models are never moved off
    # the CPU here — verify that train_p2p handles device placement itself.
    # ---- loss / optimizer -----------------------------------------------
    # Validate the configured loss explicitly rather than with `assert`,
    # which is silently stripped when Python runs with -O.
    if args['loss'] not in losses:
        raise ValueError(
            'Missing loss implementation: {!r} (expected one of {})'.format(
                args['loss'], losses))
    if args['loss'] == 'MSE':
        loss = nn.MSELoss()
    elif args['loss'] == 'VGGLoss':
        loss = p2p.VGGLoss(device_ids[0], args['use_mse'])
    parameters = list(netG.parameters())
    optimizer, scheduler = createOptim(parameters=parameters, lr=0.001)
    # ---- train ----------------------------------------------------------
    p2p.train_p2p(generator=netG, faceid=faceid, trainloader=trainloader,
                  testloader=testloader, optim=optimizer, scheduler=scheduler,
                  criterion=loss, n_epochs=args['n_epochs'],
                  e_saves=args['e_saves'], save_path=log_dir, device_ids=device_ids)