forked from zym1119/DeepLabv3_MobileNetv2_PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·59 lines (48 loc) · 1.96 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import argparse
from utils import create_dataset
from network import MobileNetv2_DeepLabv3
from config import Params
from utils import print_config
def LOG(x):
    """Print *x* to stdout wrapped in ANSI escape codes (SGR 0;31;2: red, dim).

    Args:
        x: Message to print. Any object is accepted; it is formatted with
            ``str()``, so non-string messages no longer raise ``TypeError``.
    """
    print(f'\033[0;31;2m{x}\033[0m')
def _build_parser():
    """Build the command-line parser for training and testing.

    ``--epoch`` and ``--lr`` are converted to ``int`` / ``float`` here.
    Without ``type=``, argparse would hand back strings, and those strings
    would be stored in ``params.num_epoch`` / ``params.base_lr``.
    """
    parser = argparse.ArgumentParser(description='MobileNet_v2_DeepLab_v3 Pytorch Implementation')
    parser.add_argument('--dataset', default='cityscapes', choices=['cityscapes', 'other'],
                        help='Dataset used in training MobileNet v2+DeepLab v3')
    parser.add_argument('--root', default='./data/cityscapes', help='Path to your dataset')
    parser.add_argument('--epoch', type=int, default=None, help='Total number of training epoch')
    parser.add_argument('--lr', type=float, default=None, help='Base learning rate')
    parser.add_argument('--pretrain', default=None, help='Path to a pre-trained backbone model')
    parser.add_argument('--resume_from', default=None, help='Path to a checkpoint to resume model')
    return parser


def _apply_args(params, args):
    """Copy command-line overrides from *args* onto *params* in place.

    The dataset root handling works as follows:

    * If ``--root`` exists on disk, it becomes ``params.dataset_root``.
    * Otherwise the root already configured on *params* is kept.
    * If *params* has no root either, ``ValueError`` is raised.

    Every other option overrides *params* only when it was given, i.e. when
    it is not ``None``.

    Raises:
        ValueError: ``args.root`` does not exist and
            ``params.dataset_root`` is ``None``.
    """
    if os.path.exists(args.root):
        params.dataset_root = args.root
    elif params.dataset_root is None:
        raise ValueError('ERROR: Root %s not exists!' % args.root)
    # NOTE(review): args.dataset is parsed but never copied onto params;
    # confirm whether Params is meant to receive it.
    if args.epoch is not None:
        params.num_epoch = args.epoch
    if args.lr is not None:
        params.base_lr = args.lr
    if args.pretrain is not None:
        params.pre_trained_from = args.pretrain
    if args.resume_from is not None:
        params.resume_from = args.resume_from


def main(argv=None):
    """Parse CLI options, build the dataset and network, then train and test.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches calling ``main()`` with no
            arguments.

    Raises:
        ValueError: ``--root`` does not exist and ``Params`` provides no
            ``dataset_root``.
    """
    args = _build_parser().parse_args(argv)
    params = Params()
    _apply_args(params, args)
    LOG('Network parameters:')
    print_config(params)

    # create dataset and transformation
    LOG('Creating Dataset and Transformation......')
    datasets = create_dataset(params)
    LOG('Creation Succeed.\n')

    # create model
    LOG('Initializing MobileNet and DeepLab......')
    net = MobileNetv2_DeepLabv3(params, datasets)
    LOG('Model Built.\n')

    # let's start to train!
    net.Train()
    net.Test()


if __name__ == '__main__':
    main()