-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
104 lines (73 loc) · 4.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import tensorflow as tf
import load_data as ldata
import model_v2 as model
import os
from model_v2 import BATCH_SIZE
# Number of passes over the full training set (outer-loop bound below).
EPOCH_SIZE = 1000
# Restrict TensorFlow to the first GPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def init_random(shape):
    """Return a noise array of the given shape, sampled uniformly from [0.0, 1.0)."""
    # numpy draws from the half-open interval [low, high).
    return np.random.uniform(low=0.0, high=1.0, size=shape)
def next_batch(x, y, batch_size=BATCH_SIZE):
    """Yield successive (x, y) slice pairs of length batch_size.

    The final pair may be shorter than batch_size when len(x) is not a
    multiple of it (the caller is expected to handle/skip short batches).
    """
    for start in range(0, len(x), batch_size):
        stop = start + batch_size
        yield x[start:stop], y[start:stop]
# Training driver: alternates discriminator and generator updates (TF1-style
# graph + Session). Python 2 syntax (bare `print` statements).
if __name__ == '__main__':
    # --- Graph / session setup -------------------------------------------
    sess = tf.Session()
    gan = model.GAN()  # builds generator + discriminator; see model_v2
    train_writer = tf.summary.FileWriter('here', sess.graph)

    # NOTE(review): assumes load_SVHN() returns images scaled to [0, 1]
    # (see commented-out /255. line) — confirm against load_data.
    images, labels = ldata.load_SVHN()
    # print np.max(images), np.min(images)
    # images = images / 255.
    # ldata.cv2_save(n=10, m=10, data=images[0:100], file_path="meow.png")
    # Rescale pixels from [0, 1] to [-1, 1] to match the generator's output range.
    images = (images - 0.5) * 2.
    # Sanity-check dump of the first 100 real images (mapped back to [0, 1]).
    ldata.cv2_save(n=10, m=10, data=(images[0:100] + 1) / 2., file_path="meow.png")

    # Discriminator target labels: 1 = real image, 0 = generated image.
    label_ = tf.placeholder(tf.int64, [None])
    # dis_loss: cross-entropy of the discriminator on its input batch.
    # gen_loss: same loss evaluated on generated samples (fed labels of 1
    # at train time so the generator is pushed to fool the discriminator).
    dis_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_, logits=gan.dis))
    gen_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_, logits=gan.dis_gen))
    # dis_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label_, logits=gan.dis))
    # gen_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=label_, logits=gan.dis_gen))

    # Partition variables by scope so each optimizer updates only its own net.
    # NOTE(review): this collects GLOBAL_VARIABLES, which also includes
    # non-trainable variables (e.g. batch-norm moving stats);
    # TRAINABLE_VARIABLES may have been intended — confirm against model_v2.
    dis_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='gan/dis')
    gen_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='gan/gen')
    print dis_vars, gen_vars
    dis_train_step = tf.train.AdamOptimizer(0.0001).minimize(dis_loss, var_list=dis_vars)
    gen_train_step = tf.train.AdamOptimizer(0.0001).minimize(gen_loss, var_list=gen_vars)
    # dis_train_step = tf.train.MomentumOptimizer(0.0002, 0.5).minimize(dis_loss, var_list=dis_var)
    # gen_train_step = tf.train.MomentumOptimizer(0.0002, 0.5).minimize(gen_loss, var_list=gen_var)

    # Discriminator accuracy over the batch currently fed in.
    correct_prediction = tf.equal(label_, tf.argmax(gan.dis, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # print dis_train_step
    # print gen_train_step
    sess.run(tf.global_variables_initializer())
    print tf.GraphKeys.TRAINABLE_VARIABLES
    print gen_train_step

    # --- Training loop ---------------------------------------------------
    for step in range(EPOCH_SIZE):
        batch_step = 0
        for x, _ in next_batch(images,labels):
            batch_step = batch_step + 1
            # Skip the final partial batch: reshapes below assume a full BATCH_SIZE.
            if len(x) < BATCH_SIZE: break
            # Fake half of the batch: run the generator on fresh uniform noise.
            # NOTE(review): noise width 2048 must match the generator input
            # size declared in model_v2 — confirm.
            input_noise = init_random((BATCH_SIZE, 2048))
            xn = gan.gen.eval(session=sess, feed_dict={gan.raw_input_noise:input_noise})
            yn = np.array([0] * BATCH_SIZE)  # label 0 = fake
            x_ = np.reshape(x, (BATCH_SIZE, 32 * 32 * 3))  # flatten 32x32x3 SVHN images
            y = np.array([1] * BATCH_SIZE)   # label 1 = real
            # print xn
            # Shuffle the concatenated fake+real samples with one shared
            # permutation so tx and ty stay aligned.
            rindex = [i for i in range(2 * BATCH_SIZE)]
            np.random.shuffle(rindex)
            tx = np.concatenate((xn, x_))[rindex]
            ty = np.concatenate((yn, y))[rindex]
            # Only the first BATCH_SIZE of the 2*BATCH_SIZE shuffled samples
            # are used for the accuracy probe and the discriminator update.
            train_accuracy = accuracy.eval(session=sess, feed_dict={gan.raw_input_image:tx[0:BATCH_SIZE], label_:ty[0:BATCH_SIZE]})
            print("step training accuracy %g" % (train_accuracy))
            #ldata.cv2_save(n=10, m=10, data=(tx[0:100] + 1) / 2., file_path="meow.png")
            if batch_step % 400 == 0:
                # Periodic dump of the mixed fake/real batch for inspection.
                ldata.cv2_save(n=16, m=16, data=(tx + 1.) / 2., file_path="gen/{}-{}.png".format(step, batch_step))
            # Discriminator update.
            sess.run(dis_train_step, feed_dict={gan.raw_input_image:tx[0:BATCH_SIZE], label_:ty[0:BATCH_SIZE]})
            # sess.run(dis_train_step, feed_dict={gan.raw_input_image:xn, label_:yn})
            if batch_step % 40 == 0:
                print("Epoch %d, Batch: %d" % (step, batch_step))
            # Generator update: new noise, labels forced to 1 ("real") so the
            # generator gradient pushes toward fooling the discriminator.
            input_noise = init_random((BATCH_SIZE, 2048))
            y = np.array([1] * BATCH_SIZE)
            sess.run(gen_train_step, feed_dict={gan.raw_input_noise:input_noise, label_:y})
        # End of epoch: save a 10x10 grid of freshly generated samples,
        # mapped back from [-1, 1] to [0, 1] for the image writer.
        data = gan.gen.eval(session=sess, feed_dict={gan.raw_input_noise:init_random((BATCH_SIZE, 2048))})[0:100]
        ldata.cv2_save(n=10, m=10, data=(data + 1.) / 2., file_path="gen/{}.png".format(step))