-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathmodel1.py
More file actions
89 lines (63 loc) · 1.93 KB
/
model1.py
File metadata and controls
89 lines (63 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import cv2
from keras.preprocessing import image
from keras.models import Model
from keras.layers import Input, Conv2D, Deconv2D, Activation, BatchNormalization, add
from keras.callbacks import ModelCheckpoint
from datagen import gen_data
# --- Training configuration -------------------------------------------------
SEED = 1              # training stream seed; the validation stream uses SEED + 1
EPOCHS = 40
BATCH_SIZE = 4
LOAD_WEIGHTS = False  # when True, load weights from 'model1.h5' before training

# NOTE(review): IMG_HEIGHT / IMG_WIDTH are not referenced by the visible code;
# the network below accepts inputs of any spatial size.
IMG_HEIGHT = 128
IMG_WIDTH = 128
def _conv_bn_relu(tensor, kernel_size):
    """Apply a 64-filter 'same'-padded Conv2D, BatchNormalization and ReLU to *tensor*."""
    tensor = Conv2D(64, kernel_size, padding='same')(tensor)
    tensor = BatchNormalization()(tensor)
    return Activation('relu')(tensor)


# Fully convolutional network: height and width are left as None so any
# spatial size is accepted; the input has a single channel.
inputs = Input((None, None, 1))

# A 9x9 stem stage followed by seven 3x3 stages, each conv -> BN -> ReLU.
x = _conv_bn_relu(inputs, 9)
for _ in range(7):
    x = _conv_bn_relu(x, 3)

# Single-channel output squashed into (0, 1) by a sigmoid.
outputs = Conv2D(1, 3, padding='same', activation='sigmoid')(x)
model = Model(inputs=inputs, outputs=outputs)
# Print a per-layer overview of the network.
model.summary()

# Single checkpoint file: ModelCheckpoint writes it during training, and it is
# read back here to resume when LOAD_WEIGHTS is enabled.
_CHECKPOINT_PATH = 'model1.h5'
if LOAD_WEIGHTS:
    model.load_weights(_CHECKPOINT_PATH)

# Mean squared error loss, Adam optimizer with its default hyperparameters.
model.compile(loss='MSE', optimizer='Adam')

checkpointer = ModelCheckpoint(filepath=_CHECKPOINT_PATH, verbose=1)
def _data_generator(seed, batch_size):
    """Yield ``gen_data(rnd, batch_size)`` forever from a private RandomState.

    Shared implementation behind the training and validation streams. Each call
    builds its own ``np.random.RandomState(seed)``, so separate streams never
    consume from (or perturb) one another's random sequence. The RandomState
    is created lazily, on the first ``next()``.

    Args:
        seed: Seed for this stream's RandomState.
        batch_size: Forwarded unchanged to ``gen_data``.

    Yields:
        Whatever one ``gen_data`` call returns. Presumably an
        ``(inputs, targets)`` batch as expected by ``fit_generator``;
        confirm against ``datagen.gen_data``.
    """
    rnd = np.random.RandomState(seed)
    while True:
        yield gen_data(rnd, batch_size)


def _train_generator():
    """Return an endless training-batch stream seeded with ``SEED``."""
    return _data_generator(SEED, BATCH_SIZE)


def _val_generator():
    """Return an endless validation-batch stream seeded with ``SEED + 1``.

    The distinct seed makes the validation stream draw a different random
    sequence from the training stream.
    """
    return _data_generator(SEED + 1, BATCH_SIZE)
# Endless batch streams consumed by Keras.
train_generator = _train_generator()
val_generator = _val_generator()

# Per-epoch budgets expressed in batches: 512 // BATCH_SIZE training steps and
# 32 // BATCH_SIZE validation steps (assuming gen_data yields BATCH_SIZE
# samples per call, i.e. roughly 512 and 32 samples respectively).
train_steps = 512 // BATCH_SIZE
val_steps = 32 // BATCH_SIZE

history = model.fit_generator(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=EPOCHS,
    validation_data=val_generator,
    validation_steps=val_steps,
    callbacks=[checkpointer],
)

# Save the model as it stands after the final epoch.
model.save('model1_final.h5')