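# train.py -- training entry point for Uni-WaveNet: builds the mel-spectrogram
# encoder and the WaveNet stack, sets up data loaders and TensorBoard writers,
# and hands everything to UniWaveNetTrainer.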
import argparse
import pathlib
import torch
import tensorboardX
import params
from net import Encoder
from net import UniWaveNet
from utils import DatasetFromFolder
from trainer import UniWaveNetTrainer
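
# Command-line options: --use_cuda toggles GPU training; the remaining flags
# point at saved checkpoints so an interrupted run can be resumed from
# --start_iteration.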
parser = argparse.ArgumentParser()
parser.add_argument('--use_cuda', action='store_true', help='use cuda?')
parser.add_argument('--start_iteration', '-i', type=int, default=1,
                    help='Start iteration setting for resuming')
parser.add_argument('--encoder_path', '-e', default=None,
                    help='Trained encoder path for resuming')
parser.add_argument('--wavenet_path', '-w', default=None,
                    help='Trained wavenet path for resuming')
parser.add_argument('--optimizer_path', '-o', default=None,
                    help='Optimizer state path for resuming')
args = parser.parse_args()
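
# Fail fast if CUDA was requested but no GPU is available.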
if args.use_cuda and not torch.cuda.is_available():
    raise Exception('No GPU found, please run without --use_cuda')
device = torch.device('cuda' if args.use_cuda else 'cpu')
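
# Train and validation datasets share all preprocessing settings; only the
# split name differs.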
train_dataset = DatasetFromFolder(
    params.root, params.dataset_type, params.sr, params.length,
    params.frame_length, params.hop, params.n_mels, 'train',
    params.seed)
valid_dataset = DatasetFromFolder(
    params.root, params.dataset_type, params.sr, params.length,
    params.frame_length, params.hop, params.n_mels, 'valid',
    params.seed)
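
# Shuffle only the training data; validation batches keep a fixed order.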
train_data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=params.batch_size,
    shuffle=True)
valid_data_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=params.batch_size,
    shuffle=False)
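
# Build the conditioning encoder and the Uni-WaveNet stack; the encoder
# output is sized to n_wavenets * n_layers * n_loops (one per WaveNet layer).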
encoder = Encoder(
    params.upscale_factors,
    params.n_wavenets * params.n_layers * params.n_loops, params.r,
    params.n_mels).to(device)
wavenet = UniWaveNet(
    params.n_wavenets, params.n_layers, params.n_loops, params.a, params.r,
    params.s).to(device)
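
# A single Adam optimizer updates the WaveNet and encoder parameters jointly.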
optimizer = torch.optim.Adam(
    list(wavenet.parameters()) + list(encoder.parameters()), lr=params.lr)
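
# Separate TensorBoard writers keep training and validation curves in
# distinct subdirectories of output_dir.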
train_writer = tensorboardX.SummaryWriter(
    str(pathlib.Path(params.output_dir, 'train')))
valid_writer = tensorboardX.SummaryWriter(
    str(pathlib.Path(params.output_dir, 'valid')))
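
# The trainer owns the training loop; validation/save cadence, loss
# weighting, and the gradient threshold all come from params.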
trainer = UniWaveNetTrainer(
    train_data_loader, valid_data_loader, train_writer, valid_writer,
    params.valid_iteration, params.save_iteration, device, encoder, wavenet,
    optimizer, params.loss_weights, params.scale, params.loss_threshold,
    params.sr, params.output_dir, params.gradient_threshold)
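
# Load checkpoint state when resuming; the paths default to None, which the
# trainer presumably treats as "start fresh".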
trainer.load_trained_encoder(args.encoder_path)
trainer.load_trained_wavenet(args.wavenet_path)
trainer.load_optimizer_state(args.optimizer_path)
trainer.run(params.n_iteration, args.start_iteration)