-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrain.py
More file actions
109 lines (96 loc) · 3.07 KB
/
train.py
File metadata and controls
109 lines (96 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
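#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
# File       : train.py
# Time       : 2023/9/15 10:18
# Author     : yujia
# version    : python 3.6
# Description: Train a captcha text-recognition model (crnnlite by default)
#              with CTC loss, saving a checkpoint every epoch plus best.pth.
"""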
import random
import torch.backends.cudnn as cudnn
import torch
import numpy as np
import os
from tqdm import tqdm
from torch.utils.data import DataLoader
from utils import utils
from tool import dataloader
from utils import utils_lr
from tool import load, process
class Opt:
    """Static training configuration.

    Every setting is a class attribute, so an ``Opt()`` instance simply exposes
    them; edit the values here to change a training run.
    """

    # Data and output locations
    trainRoot: str = r"data"  # dataset root handed to dataloader.load_dataset
    cuda: bool = True  # request the GPU (used only if CUDA is available)
    pretrained: str = ''  # presumably a checkpoint path, '' = none; consumed elsewhere (confirm)
    alphabet_path: str = 'tool/charactes_keys.txt'  # character set file for get_charactes_keys
    expr_dir: str = 'expr'  # directory receiving model_<epoch>.pth and best.pth
    # Schedule and model size
    nepoch: int = 100  # number of training epochs
    batchSize: int = 64  # DataLoader batch size (also given to the split and LR scheduler)
    nh: int = 256  # presumably the recurrent hidden size; read by load.load_model (confirm)
    nc: int = 3  # passed to CaptchaDataset; presumably the image channel count (confirm)
    workers: int = 0  # DataLoader num_workers
    # Input image geometry
    imgH: int = 32  # image height given to CaptchaDataset
    imgW: int = 100  # image width given to CaptchaDataset
    # Optimisation
    lr: float = 0.001  # base learning rate for the LR scheduler
    beta1: float = 0.5  # presumably Adam's beta1, read by load.load_optimizer (confirm)
    optimizer_type: str = "Adam"  # optimizer name given to the LR scheduler factory
    model_name: str = "crnnlite"  # architecture selector for load.load_model
    # Reproducibility and data split
    manualSeed: int = 1234  # seed for random, numpy and torch
    train_ratio: float = 0.9  # fraction of samples used for training
# Instantiate the configuration and make sure the checkpoint directory exists.
# exist_ok=True avoids the race between a separate existence check and makedirs.
opt = Opt()
os.makedirs(opt.expr_dir, exist_ok=True)

# Use the GPU only when it is both requested (opt.cuda) and actually available.
device = torch.device("cuda" if opt.cuda and torch.cuda.is_available() else "cpu")

# Random seeds
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
# NOTE: cudnn.benchmark = True lets cuDNN auto-select kernels, which can make GPU
# runs non-deterministic despite the seeds above. Set it to False (and
# cudnn.deterministic = True) if bit-exact reproducibility is needed.
cudnn.benchmark = True
# Training / validation data
alphabet = dataloader.get_charactes_keys(opt.alphabet_path)
lines, labels = dataloader.load_dataset(opt.trainRoot)
# Split into train/validation subsets using opt.train_ratio.
# NOTE(review): ratio_dataloader also receives batchSize; presumably the split
# is sized in whole batches. Confirm against tool/dataloader.
train_lines, train_labels, val_lines, val_labels = dataloader.ratio_dataloader(
    lines, labels, opt.train_ratio, opt.batchSize
)
train_dataset = dataloader.CaptchaDataset([train_lines, train_labels], [opt.imgH, opt.imgW], opt.nc)
# DataLoader rejects shuffle=True together with an explicit sampler, so shuffle
# only when no sampler is supplied.
sampler = None
train_loader = DataLoader(
    train_dataset,
    batch_size=opt.batchSize,
    shuffle=sampler is None,
    sampler=sampler,
    num_workers=int(opt.workers),
)
test_dataset = dataloader.CaptchaDataset([val_lines, val_labels], [opt.imgH, opt.imgW], opt.nc)
# The validation loader does not shuffle and does not reuse the training
# sampler. Shuffling gains nothing for evaluation, and a fixed order keeps
# validation independent of the RNG from epoch to epoch.
test_loader = DataLoader(
    test_dataset,
    batch_size=opt.batchSize,
    shuffle=False,
    num_workers=int(opt.workers),
)

# Model: architecture selected by opt.model_name.
# NOTE(review): the model is not moved to `device` here; presumably
# load.load_model or process.fit_epoch handles placement. Confirm.
model = load.load_model(opt, alphabet, opt.model_name)
# Optimizer, plus the learning-rate schedule applied once per epoch.
optimizer = load.load_optimizer(opt, model)
lr_scheduler_func = utils_lr.get_lr_scheduler_func(opt.lr, opt.optimizer_type, opt.batchSize, opt.nepoch)
# Loss: CTC with PyTorch defaults (blank index 0, zero_infinity=False).
# With zero_infinity=False, a sample whose target is longer than the model's
# output sequence yields an infinite loss; consider zero_infinity=True if so.
criterion = torch.nn.CTCLoss()
# Label encoder/decoder built from the alphabet (presumably text <-> CTC indices).
converter = utils.strLabelConverter(alphabet)
# Best validation accuracy seen so far. best.pth is refreshed whenever an epoch
# matches or beats it.
acc = 0
for epoch in range(1, opt.nepoch + 1):
    # Update the learning rate for this epoch.
    utils_lr.set_optimizer_lr(optimizer, lr_scheduler_func, epoch)

    # Training. The `with` block closes the bar even if fit_epoch raises.
    # No initial postfix is passed; fit_epoch may set one via pbar.set_postfix.
    num_iterations = len(train_loader)
    with tqdm(total=num_iterations, desc=f'Train Epoch {epoch}/{opt.nepoch}', mininterval=0.3) as pbar:
        process.fit_epoch(train_loader, model, criterion, optimizer, converter, device, pbar)

    # Validation
    num_val = len(test_loader)
    with tqdm(total=num_val, desc=f'Validation Epoch {epoch}/{opt.nepoch}', mininterval=0.3) as pbar:
        val_acc = process.val(test_loader, model, criterion, converter, device, pbar)

    # Save a checkpoint every epoch, and refresh best.pth on improvement or tie.
    state_dict = model.state_dict()
    torch.save(state_dict, os.path.join(opt.expr_dir, f"model_{epoch}.pth"))
    if val_acc >= acc:
        torch.save(state_dict, os.path.join(opt.expr_dir, "best.pth"))
        acc = val_acc