#!/usr/bin/python
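# Usage example (all flags optional; the values shown are the defaults declared below):
#   ./train --batch_size 3 --epochs 100 --n_gpu 1 --weights weights.h5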
import argparse
from keras import optimizers
from keras.callbacks import ModelCheckpoint, TensorBoard, LearningRateScheduler
import model
from utils import ExpDecay, WeightedCrossentropy, MapillaryGenerator, Visualization, make_parallel
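# Project-local helpers (see utils; descriptions inferred from how they are used
# below): ExpDecay builds a per-epoch learning-rate schedule, WeightedCrossentropy
# a class-weighted loss, MapillaryGenerator the Mapillary Vistas data loader,
# Visualization a prediction-preview callback, and make_parallel a multi-GPU wrapper.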
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=3, help='input batch size')
parser.add_argument('--image_width', type=int, default=640, help='the input image width')
parser.add_argument('--image_height', type=int, default=320, help='the input image height')
parser.add_argument('--crop_width', type=int, default=224, help='the crop width')
parser.add_argument('--crop_height', type=int, default=224, help='the crop height')
# NB: argparse's type=bool would treat any non-empty string (even 'False') as True,
# so parse the flag explicitly
parser.add_argument('--random_flip', type=lambda s: s.lower() in ('true', '1'), default=True, help='do random horizontal flips during training')
parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate of the optimizer')
parser.add_argument('--decay', type=float, default=0.995, help='exponential learning rate decay (per epoch)')
parser.add_argument('--n_gpu', type=int, default=1, help='Number of GPUs to use')
parser.add_argument('--n_cpu', type=int, default=8, help='Number of CPU threads to use during data generation')
parser.add_argument('--weights', type=str, default='weights.h5', help='output path for the weights file')
opt = parser.parse_args()
print(opt)
# Batch size has to be scaled by the number of GPUs (each GPU processes batch_size samples)
opt.batch_size *= opt.n_gpu
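# e.g. --batch_size 3 with --n_gpu 2 gives an effective batch of 6 (3 samples per GPU)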
#### Train ####
# Callbacks
checkpoint = ModelCheckpoint(opt.weights, monitor='val_categorical_accuracy', mode='max', save_best_only=True, save_weights_only=True, verbose=1)
tensorboard = TensorBoard(batch_size=opt.batch_size)
visualization = Visualization(resize_shape=(opt.crop_width, opt.crop_height), batch_steps=10, n_gpu=opt.n_gpu)
lr_decay = LearningRateScheduler(ExpDecay(opt.lr, opt.decay).scheduler)
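# ExpDecay(lr, decay).scheduler is assumed to return lr * decay**epoch, i.e. a smooth
# exponential decay; LearningRateScheduler applies it at the start of every epoch.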
# Generators
train_generator = MapillaryGenerator(batch_size=opt.batch_size, crop_shape=(opt.crop_width, opt.crop_height), resize_shape=(opt.image_width, opt.image_height))
val_generator = MapillaryGenerator(mode='validation', batch_size=opt.batch_size, crop_shape=None, resize_shape=(opt.crop_width, opt.crop_height))
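# Per the arguments above: training resizes to image_width x image_height and then
# takes random crop_width x crop_height crops, while validation resizes straight to
# the crop size and does no cropping (crop_shape=None).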
# Optimizer
optim = optimizers.SGD(lr=opt.lr, momentum=0.9)
# Model
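# "tiramisu" presumably refers to the FC-DenseNet ("One Hundred Layers Tiramisu")
# architecture defined in model.py; 66 is the Mapillary Vistas class count.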
tiramisu = model.build(opt.crop_width, opt.crop_height, 66)
tiramisu = make_parallel(tiramisu, opt.n_gpu)
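# make_parallel is assumed to replicate the model across n_gpu GPUs and split each
# batch among them (data parallelism), returning the model unchanged when n_gpu=1.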
# Training
tiramisu.compile(optim, WeightedCrossentropy().loss, metrics=['categorical_accuracy'])
tiramisu.fit_generator(train_generator, steps_per_epoch=len(train_generator), epochs=opt.epochs,
                       callbacks=[checkpoint, tensorboard, visualization, lr_decay],
                       validation_data=val_generator, validation_steps=len(val_generator), workers=opt.n_cpu)
###############