-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathmodel.py
executable file
·98 lines (80 loc) · 4.19 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
'''
* @author [Zizhao Zhang]
* @email [[email protected]]
* @create date 2017-05-25 02:21:13
* @modify date 2017-05-25 02:21:13
* @desc U-Net model definition built with Keras layers from tensorflow.contrib
'''
import tensorflow as tf
try:
from tensorflow.contrib import keras as keras
print ('load keras from tensorflow package')
except:
print ('update your tensorflow')
from tensorflow.contrib.keras import models
from tensorflow.contrib.keras import layers
class UNet():
    """Builder for a U-Net style encoder/decoder Keras model.

    The network has four down-sampling levels (32, 64, 128 and 256 filters),
    a 512-filter bottleneck and four up-sampling levels that concatenate the
    matching encoder features (cropped to size) before convolving again.
    """

    def __init__(self):
        print ('build UNet ...')

    def _split_margin(self, target_dim, refer_dim, axis_name):
        """Split ``target_dim - refer_dim`` into a (before, after) pair.

        Args:
            target_dim: size of the larger tensor along one axis; either an
                object exposing ``.value`` (TF1 ``Dimension``) or a plain int.
            refer_dim: size of the smaller tensor along the same axis, in the
                same form as ``target_dim``.
            axis_name: label used in error messages only.

        Returns:
            ``(before, after)`` non-negative ints summing to the difference;
            when the difference is odd, the extra unit goes to ``after``.

        Raises:
            ValueError: if either size is unknown (``None``) or the target is
                smaller than the reference.
        """
        # Accept both Dimension-like objects and plain integers.
        target_size = getattr(target_dim, 'value', target_dim)
        refer_size = getattr(refer_dim, 'value', refer_dim)
        if target_size is None or refer_size is None:
            raise ValueError(
                'cannot compute crop along %s: dimension is unknown '
                '(target=%r, refer=%r)' % (axis_name, target_size, refer_size))

        margin = target_size - refer_size
        # Explicit raise instead of `assert`, which is stripped under -O.
        if margin < 0:
            raise ValueError(
                'target %s (%d) is smaller than refer %s (%d)'
                % (axis_name, target_size, axis_name, refer_size))

        before = margin // 2
        return before, margin - before

    def get_crop_shape(self, target, refer):
        """Compute how much ``target`` exceeds ``refer`` in height and width.

        Args:
            target: tensor-like object whose ``get_shape()`` is indexable;
                index 1 is treated as height and index 2 as width.
            refer: tensor-like object with the same layout, expected to be no
                larger than ``target`` along both axes.

        Returns:
            ``((ch1, ch2), (cw1, cw2))``: the height difference and the width
            difference, each split into two halves (second half gets the
            extra unit when the difference is odd).

        Raises:
            ValueError: if a height/width dimension is unknown or ``target``
                is smaller than ``refer`` along either axis.
        """
        target_shape = target.get_shape()
        refer_shape = refer.get_shape()
        # width, the 3rd dimension (checked first, as before)
        cw = self._split_margin(target_shape[2], refer_shape[2], 'width')
        # height, the 2nd dimension
        ch = self._split_margin(target_shape[1], refer_shape[1], 'height')
        return ch, cw

    def _conv_block(self, x, filters, first_name=None):
        """Apply two 3x3 'same'-padded ReLU convolutions with ``filters`` maps.

        ``first_name`` is passed as the ``name`` of the first convolution only;
        ``None`` lets Keras pick its default name.
        """
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same',
                          name=first_name)(x)
        return layers.Conv2D(filters, (3, 3), activation='relu',
                             padding='same')(x)

    def _up_merge(self, x, skip, concat_axis):
        """Upsample ``x`` by 2 and concatenate it with ``skip`` cropped to fit."""
        upsampled = layers.UpSampling2D(size=(2, 2))(x)
        # The encoder feature map may be larger than the upsampled one when a
        # pooled dimension was odd; crop it so both can be concatenated.
        ch, cw = self.get_crop_shape(skip, upsampled)
        cropped = layers.Cropping2D(cropping=(ch, cw))(skip)
        return layers.concatenate([upsampled, cropped], axis=concat_axis)

    def create_model(self, img_shape, num_class):
        """Assemble the U-Net graph and return it as a Keras ``Model``.

        Args:
            img_shape: shape passed to ``layers.Input`` (without the batch
                axis). The code concatenates on axis 3 and reads height/width
                from axes 1 and 2, i.e. it assumes a channels-last layout.
            num_class: number of filters of the final 1x1 convolution.

        Returns:
            A ``models.Model`` mapping the input to the output of the final
            1x1 convolution; no activation is applied to that last layer.
        """
        concat_axis = 3

        inputs = layers.Input(shape=img_shape)

        # Contracting path: keep each level's features for the skip links.
        skips = []
        x = inputs
        for level, filters in enumerate((32, 64, 128, 256)):
            # Only the very first convolution carries an explicit name.
            first_name = 'conv1_1' if level == 0 else None
            x = self._conv_block(x, filters, first_name=first_name)
            skips.append(x)
            x = layers.MaxPooling2D(pool_size=(2, 2))(x)

        # Bottleneck.
        x = self._conv_block(x, 512)

        # Expanding path: deepest skip connection is consumed first.
        for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
            x = self._up_merge(x, skip, concat_axis)
            x = self._conv_block(x, filters)

        # Pad back to the input's spatial size in case cropping shrank it.
        ch, cw = self.get_crop_shape(inputs, x)
        x = layers.ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(x)
        outputs = layers.Conv2D(num_class, (1, 1))(x)

        model = models.Model(inputs=inputs, outputs=outputs)
        return model