-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.py
More file actions
90 lines (79 loc) · 3.31 KB
/
config.py
File metadata and controls
90 lines (79 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
# Directory containing this file; every dataset path below is resolved against it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Sub-folder names for the four splits, in the order
# (print train, vein train, print test, vein test).
# Most datasets use hyphens; CUMT2's folders use underscores instead.
_HYPHEN_SPLITS = ('print-train', 'vein-train', 'print-test', 'vein-test')
_UNDERSCORE_SPLITS = ('print_train', 'vein_train', 'print_test', 'vein_test')


def _dataset_entry(root, num_classes, splits=_HYPHEN_SPLITS):
    """Build one ``DATASET_CONFIG`` entry for a dataset stored under ``root``.

    Args:
        root: Absolute dataset directory; stored as ``data_dir``.
        num_classes: Value stored under ``num_classes``.
        splits: Four folder names (print-train, vein-train, print-test,
            vein-test order) joined onto ``root`` to form the split dirs.

    Returns:
        A fresh dict with keys ``data_dir``, the four ``*_dir`` split paths,
        ``num_classes``, ``img_size`` ((128, 128) for every dataset) and
        ``in_channels`` (3 for every dataset).
    """
    print_train, vein_train, print_test, vein_test = splits
    return {
        'data_dir': root,
        'print_train_dir': os.path.join(root, print_train),
        'vein_train_dir': os.path.join(root, vein_train),
        'print_test_dir': os.path.join(root, print_test),
        'vein_test_dir': os.path.join(root, vein_test),
        'num_classes': num_classes,
        'img_size': (128, 128),
        'in_channels': 3,
    }


# Registry of supported datasets, keyed by the name passed to get_dataset_config().
# HandsData lives directly under BASE_DIR; the others live under BASE_DIR/data2.
DATASET_CONFIG = {
    'HandsData': _dataset_entry(os.path.join(BASE_DIR, 'HandsData'), 290),
    'CASIA': _dataset_entry(os.path.join(BASE_DIR, 'data2', 'CASIA'), 200),
    'QH': _dataset_entry(os.path.join(BASE_DIR, 'data2', 'QH'), 500),
    'TJ': _dataset_entry(os.path.join(BASE_DIR, 'data2', 'TJ'), 600),
    'CUMT2': _dataset_entry(
        os.path.join(BASE_DIR, 'data2', 'CUMT2'), 532, splits=_UNDERSCORE_SPLITS
    ),
}
# Dataset name used by default; expected to be one of the DATASET_CONFIG keys.
DEFAULT_DATASET = 'HandsData'

# --- Data loading and optimisation schedule ---------------------------------
BATCH_SIZE = 32
NUM_WORKERS = 4  # presumably the number of data-loader worker processes
NUM_EPOCHS = 100
LEARNING_RATE = 0.0001
WEIGHT_DECAY = 0.0001

# --- Reproducibility ---------------------------------------------------------
# Presumably consumed by the training script to seed RNGs and configure cuDNN
# (deterministic kernels on, autotuning off) — confirm against the caller.
SEED = 42
DETERMINISTIC = True
CUDNN_BENCHMARK = False

# --- Input normalisation (per-channel, 3 channels) ---------------------------
# These values match the widely used ImageNet RGB mean/std.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# --- Model architecture -------------------------------------------------------
FEATURE_DIM = 256
NUM_EXPERTS = 3  # presumably the number of mixture-of-experts branches
OUT_STAGES = [3, 4, 5]  # backbone stages whose outputs are used (assumed; verify in model code)
REDUCER_CHANNELS = 64
LOAD_BALANCE_WEIGHT = 0.01  # presumably the weight of the MoE load-balancing loss term
def get_dataset_config(dataset_name):
    """Look up the configuration registered under ``dataset_name``.

    Args:
        dataset_name: A key of ``DATASET_CONFIG`` (e.g. ``'HandsData'``).

    Returns:
        The entry stored in ``DATASET_CONFIG`` itself, not a copy, so
        mutating the result also changes the module-level configuration.

    Raises:
        ValueError: If ``dataset_name`` is not registered; the message lists
            the available dataset names.
    """
    try:
        return DATASET_CONFIG[dataset_name]
    except KeyError:
        pass
    # Raised outside the except clause so no KeyError context is attached.
    raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(DATASET_CONFIG.keys())}")
def get_save_dir(dataset_name, checkpoint_root='checkpoints4'):
    """Return the checkpoint directory path for ``dataset_name``.

    Args:
        dataset_name: Dataset key; used verbatim as the last path component.
        checkpoint_root: Name of the checkpoint folder under ``BASE_DIR``.
            Defaults to ``'checkpoints4'`` (the previously hard-coded value),
            so existing callers get the same path. Pass e.g.
            ``'checkpoints2'`` or ``'checkpoints3'`` to address earlier
            experiment runs without editing this function.

    Returns:
        ``os.path.join(BASE_DIR, checkpoint_root, dataset_name)``. The
        directory is neither created nor checked for existence here.
    """
    return os.path.join(BASE_DIR, checkpoint_root, dataset_name)
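# checkpoints2: the backbone feature extractor extracts features at three levels (layers).
# checkpoints3: a load-balancing loss was added to the MoE.
# checkpoints4: data preprocessing + back-propagation of the load-balancing loss.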