-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSiamese NN by Contrastive Loss.py
251 lines (176 loc) · 8.15 KB
/
Siamese NN by Contrastive Loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
''' Siamese Neural Network with a custom loss (Contrastive Loss), using the Keras Functional API, Saber
'''
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, Lambda
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras import backend as K #for operaations/math in Keras
import numpy as np
import matplotlib.pyplot as plt
import random
#_____________________________________________________________
#Prepare the Dataset
def create_pairs(x, digit_indices):
    '''Positive and negative pair creation.

    Alternates between positive and negative pairs. For every class ``d`` and
    every position ``i`` (up to one less than the size of the smallest class),
    two pairs are appended:

    * a positive pair ``(x[digit_indices[d][i]], x[digit_indices[d][i + 1]])``
      with label 1, and
    * a negative pair ``(x[digit_indices[d][i]], x[digit_indices[dn][i]])``
      with label 0, where ``dn`` is a randomly chosen class different from ``d``.

    Args:
        x: indexable collection of samples (e.g. a NumPy array of images).
        digit_indices: one sequence of indices into ``x`` per class. The number
            of classes is ``len(digit_indices)`` (10 for the script's data).

    Returns:
        ``(pairs, labels)`` as NumPy arrays; ``labels`` alternates 1, 0, 1, 0, ...

    Raises:
        ValueError: if fewer than two classes are given (no negative pairs exist).
    '''
    num_classes = len(digit_indices)
    if num_classes < 2:
        raise ValueError("create_pairs needs at least two classes to build negative pairs")
    pairs = []
    labels = []
    # Positions usable in every class; -1 leaves room for the i + 1 partner.
    n = min(len(digit_indices[d]) for d in range(num_classes)) - 1
    for d in range(num_classes):
        for i in range(n):
            # Positive pair: two consecutive samples of the same class.
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            pairs += [[x[z1], x[z2]]]
            # Negative pair: a non-zero class offset guarantees dn != d.
            inc = random.randrange(1, num_classes)
            dn = (d + inc) % num_classes
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            pairs += [[x[z1], x[z2]]]
            labels += [1, 0]
    return np.array(pairs), np.array(labels)
def create_pairs_on_set(images, labels):
    '''Group ``images`` by class label (0-9) and build labelled image pairs.

    Returns the pair array produced by ``create_pairs`` together with its
    1/0 similarity targets cast to ``float32``.
    '''
    per_class_indices = []
    for class_id in range(10):
        per_class_indices.append(np.where(labels == class_id)[0])
    pair_array, targets = create_pairs(images, per_class_indices)
    return pair_array, targets.astype('float32')
def show_image(image):
    '''Open a new figure showing ``image`` with a colour bar and no grid, then display it.'''
    plt.figure()
    mappable = plt.imshow(image)
    plt.colorbar(mappable)
    plt.grid(False)
    plt.show()
#__________________________________________________________
# Load the Fashion-MNIST dataset and prepare the train and test sets, then
# build the image pairs that feed the two-input (Siamese) model.
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# cast to float32 and scale pixel values from [0, 255] into [0, 1]
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0
# create labelled pairs on the train and test sets
tr_pairs, tr_y = create_pairs_on_set(train_images, train_labels)
ts_pairs, ts_y = create_pairs_on_set(test_images, test_labels)
#_______________________________________________________________
# Visualize one sample test pair (left image, then right image) and print its
# similarity label.
this_pair = 7
for side in (0, 1):
    show_image(ts_pairs[this_pair][side])
print(ts_y[this_pair])
# Show the first two training pairs: both left images first, then both right images.
for side in (0, 1):
    for pair_idx in (0, 1):
        show_image(tr_pairs[pair_idx][side])
#_______________________________________________________
# Building the model with the Functional API (Siamese NN), NOT the Sequential API.
def initialize_base_network(input_shape=(28, 28), units=128, dropout_rate=0.1):
    '''Build the base network that maps one image to a feature vector.

    The Siamese model calls this single Model on both of its inputs.

    Args:
        input_shape: shape of one input image (default ``(28, 28)``).
        units: width of each of the three Dense layers, which is also the
            size of the returned feature vector.
        dropout_rate: dropout rate applied after the first and second Dense layers.

    Returns:
        A Keras ``Model`` from an ``input_shape`` image to a ``units``-wide
        ReLU output.
    '''
    # Named base_input rather than ``input`` so the builtin is not shadowed.
    base_input = Input(shape=input_shape, name="base_input")
    x = Flatten(name="flatten_input")(base_input)
    x = Dense(units, activation='relu', name="first_base_dense")(x)
    x = Dropout(dropout_rate, name="first_dropout")(x)
    x = Dense(units, activation='relu', name="second_base_dense")(x)
    x = Dropout(dropout_rate, name="second_dropout")(x)
    x = Dense(units, activation='relu', name="third_base_dense")(x)
    return Model(inputs=base_input, outputs=x)
#____________________________________________________
def euclidean_distance(vects):
    '''Euclidean distance between two batches of vectors, one value per row.

    ``vects`` is a pair ``(x, y)``. Squared differences are summed over axis 1
    and kept as a column (``keepdims=True``). The sum is floored at
    ``K.epsilon()`` before the square root so the result (and the sqrt
    gradient) never hits zero.
    '''
    left, right = vects
    squared_diff = K.square(left - right)
    squared_dist = K.sum(squared_diff, axis=1, keepdims=True)
    return K.sqrt(K.maximum(squared_dist, K.epsilon()))
def eucl_dist_output_shape(shapes):
    '''Return the distance layer's output shape: ``(batch_dim, 1)``.

    ``shapes`` must hold exactly two shapes; only the first dimension of the
    first shape is used.
    '''
    first_shape, _ = shapes
    batch_dim = first_shape[0]
    return (batch_dim, 1)
base_network = initialize_base_network()
#________________________________________________________
# Build the Siamese network: both inputs go through the same base_network object.
# left branch
input_a = Input(shape=(28,28,), name="left_input")
vect_output_a = base_network(input_a)
# right branch (same base_network instance as the left one)
input_b = Input(shape=(28,28,), name="right_input")
vect_output_b = base_network(input_b)
# the model's single output is the distance between the two feature vectors
distance_layer = Lambda(
    euclidean_distance,
    name="output_layer",
    output_shape=eucl_dist_output_shape,
)
output = distance_layer([vect_output_a, vect_output_b])
model = Model([input_a, input_b], output)
#___________________________________________
# Training the model with a custom loss (contrastive loss) built as a closure:
# the outer function fixes the margin hyperparameter, the inner function has
# the (y_true, y_pred) signature passed to model.compile.
def contrastive_loss_with_margin(margin):
    '''Return a contrastive loss function that uses the given ``margin``.'''
    def contrastive_loss(y_true, y_pred):
        '''Contrastive loss (Hadsell, Chopra & LeCun, 2006).
        http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

        ``y_pred`` is treated as a distance. Pairs with y_true = 1 contribute
        y_pred**2; pairs with y_true = 0 contribute max(margin - y_pred, 0)**2.
        The element-wise result is returned without any reduction here.
        '''
        similar_term = y_true * K.square(y_pred)
        hinge = K.maximum(margin - y_pred, 0)
        dissimilar_term = (1 - y_true) * K.square(hinge)
        return similar_term + dissimilar_term
    return contrastive_loss
# RMSprop (Root Mean Square Propagation) optimizer with its default settings
rms = RMSprop()
model.compile(
    loss=contrastive_loss_with_margin(margin=1),
    optimizer=rms,
)
history = model.fit(
    [tr_pairs[:, 0], tr_pairs[:, 1]],
    tr_y,
    epochs=20,
    batch_size=128,
    validation_data=([ts_pairs[:, 0], ts_pairs[:, 1]], ts_y),
)
#______________________________________________________________
# Model evaluation
def compute_accuracy(y_true, y_pred, threshold=0.5):
    '''Compute classification accuracy with a fixed threshold on distances.

    A pair is predicted "similar" (label 1) when its distance is strictly
    below ``threshold``. The accuracy is the fraction of pairs whose
    prediction matches ``y_true``.

    Args:
        y_true: 1/0 similarity labels, one per pair (any shape, e.g. (N,) or (N, 1)).
        y_pred: predicted distances, one per pair (e.g. the (N, 1) output of
            ``model.predict``).
        threshold: distance cut-off; 0.5 by default.

    Returns:
        Mean accuracy over all pairs.
    '''
    # Flatten BOTH inputs: comparing an (N,) array with an (N, 1) array would
    # broadcast to (N, N) and silently give a meaningless accuracy.
    pred = np.ravel(y_pred) < threshold
    return np.mean(pred == np.ravel(y_true))
########################
# Test-set loss first, then thresholded accuracy on the train and test pairs.
loss = model.evaluate(x=[ts_pairs[:, 0], ts_pairs[:, 1]], y=ts_y)

y_pred_train = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
train_accuracy = compute_accuracy(tr_y, y_pred_train)

y_pred_test = model.predict([ts_pairs[:, 0], ts_pairs[:, 1]])
test_accuracy = compute_accuracy(ts_y, y_pred_test)

print("Loss = {}, Train Accuracy = {} Test Accuracy = {}".format(
    loss, train_accuracy, test_accuracy))
#########################################
# Visualize the evaluation metrics:
def plot_metrics(metric_name, title, ylim=5, hist=None):
    '''Plot a training metric and its validation counterpart on the current figure.

    Args:
        metric_name: key into ``History.history`` (e.g. ``'loss'``); the
            validation curve is read from ``'val_' + metric_name``.
        title: plot title.
        ylim: upper y-axis limit (the lower limit is 0).
        hist: Keras ``History`` object to plot. When omitted, the module-level
            ``history`` is used, which preserves the original behaviour.
    '''
    if hist is None:
        hist = history
    plt.title(title)
    plt.ylim(0, ylim)
    plt.plot(hist.history[metric_name], color='blue', label=metric_name)
    plt.plot(hist.history['val_' + metric_name], color='green', label='val_' + metric_name)
    # The curves carry labels; without a legend those labels are never shown.
    plt.legend()
# training vs. validation loss, y-axis clipped to [0, 0.2]
plot_metrics('loss', "Loss", ylim=0.2)
#####################################
# Matplotlib config
def visualize_images():
    '''Apply Matplotlib rc settings for image plots.

    Sets a reversed-gray colormap, zero-width grid lines, no tick marks with
    large tick labels, light grey axes/figure backgrounds and a dark red
    text colour.
    '''
    rc_settings = [
        ('image', {'cmap': 'gray_r'}),
        ('grid', {'linewidth': 0}),
        ('xtick', {'top': False, 'bottom': False, 'labelsize': 'large'}),
        ('ytick', {'left': False, 'right': False, 'labelsize': 'large'}),
        ('axes', {'facecolor': 'F8F8F8', 'titlesize': "large", 'edgecolor': 'white'}),
        ('text', {'color': 'a8151a'}),
        ('figure', {'facecolor': 'F0F0F0'}),
    ]
    for group, options in rc_settings:
        plt.rc(group, **options)
# utility to display a row of image pairs with their predicted distances
def _tile_row(images, n):
    '''Lay ``n`` 28x28 images side by side as a single 28 x (28*n) strip.'''
    tiles = np.reshape(images, [n, 28, 28])
    tiles = np.swapaxes(tiles, 0, 1)
    return np.reshape(tiles, [28, 28 * n])


def display_images(left, right, predictions, labels, title, n, threshold=0.5):
    '''Show ``n`` left images in one strip and their right partners in a second strip.

    The second strip has one tick per image, labelled with the predicted
    distance. A label is coloured red when the prediction is wrong: a distance
    below ``threshold`` means "similar", which should agree with ``labels``
    (1 = similar, 0 = dissimilar).

    Args:
        left, right: ``n`` images each, reshapeable to (n, 28, 28).
        predictions: ``n`` predicted distances.
        labels: ``n`` ground-truth 1/0 similarity labels.
        title: title of the first (left-images) figure.
        n: number of pairs to show.
        threshold: distance cut-off between "similar" and "dissimilar".
    '''
    plt.figure(figsize=(17, 3))
    plt.title(title)
    plt.yticks([])
    plt.xticks([])
    # grid(False), not grid(None): with no arguments grid(None) *toggles* the grid.
    plt.grid(False)
    plt.imshow(_tile_row(left, n))

    plt.figure(figsize=(17, 3))
    plt.yticks([])
    # one tick centred under each 28-pixel-wide image
    plt.xticks([28 * x + 14 for x in range(n)], predictions)
    for i, t in enumerate(plt.gca().xaxis.get_ticklabels()):
        predicted_similar = predictions[i] < threshold
        if predicted_similar != bool(labels[i]):
            t.set_color('red')  # bad predictions in red
    plt.grid(False)
    plt.imshow(_tile_row(right, n))
############
# Sample results for 10 distinct, randomly chosen training pairs.
# Apply the Matplotlib rc styling from visualize_images(); it was otherwise never used.
visualize_images()
y_pred_train = np.squeeze(y_pred_train)
# replace=False so the same pair cannot be drawn twice
indexes = np.random.choice(len(y_pred_train), size=10, replace=False)
display_images(tr_pairs[:, 0][indexes], tr_pairs[:, 1][indexes], y_pred_train[indexes], tr_y[indexes], "clothes and their dissimilarity", 10)
# The loss plot and the display_images figures are only created, never shown,
# when this runs as a plain script; block here so they actually appear.
plt.show()