-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
73 lines (49 loc) · 1.93 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import sys
import logging
import datetime
from collections import OrderedDict
import cv2
import torch
import torch.distributed as dist
# Process-wide OpenCV configuration, applied once when this module is imported.
# setNumThreads(0): per the OpenCV docs, disables OpenCV's internal threading so
# its functions run sequentially.
# setUseOpenCL(False): turns off OpenCL acceleration in OpenCV.
# NOTE(review): presumably done to avoid thread oversubscription / deadlocks when
# OpenCV is called from PyTorch DataLoader worker processes — confirm.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
def synchronize():
    """Wait at a barrier until every process in the default group arrives.

    This is a no-op when ``torch.distributed`` is unavailable, when no default
    process group has been initialized, or when the group has exactly one
    process.
    """
    # Short-circuit keeps get_world_size() from being called on an
    # uninitialized (or unavailable) distributed backend.
    distributed_active = dist.is_available() and dist.is_initialized()
    if distributed_active and dist.get_world_size() != 1:
        dist.barrier()
def add_handler(output_dir, log_name, mode="a"):
    """Attach a log-file handler and a stdout handler to the root logger.

    Both handlers share one formatter. The file handler writes to
    ``os.path.join(output_dir, log_name)`` and is added first. The stdout
    handler is added second. The root logger's level is not changed, and
    every call attaches two new handlers.

    Args:
        output_dir: Directory that contains the log file.
        log_name: File name of the log file inside ``output_dir``.
        mode: File open mode forwarded to ``logging.FileHandler``.
    """
    shared_format = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    root_logger = logging.getLogger()
    log_path = os.path.join(output_dir, log_name)
    new_handlers = (
        logging.FileHandler(log_path, mode=mode),
        logging.StreamHandler(sys.stdout),
    )
    for handler in new_handlers:
        handler.setFormatter(shared_format)
        root_logger.addHandler(handler)
def load_checkpoint(ckp_path, model, optimizer=None, scheduler=None, device="cuda"):
    """Restore model (and optionally optimizer/scheduler) state from a checkpoint.

    The checkpoint must be a mapping with the keys ``"epoch"`` and
    ``"state_dict"``. The keys ``"optimizer"`` and ``"lr_scheduler"`` are read
    only when the corresponding object is passed in.

    If the state dict does not load as-is, each key's *leading* ``"module."``
    prefix is stripped and loading is retried. This prefix is what
    ``DataParallel`` / ``DistributedDataParallel`` add when a wrapped model is
    saved.

    Args:
        ckp_path: Path to a file written by ``torch.save``.
        model: Module to load weights into.
        optimizer: Optional optimizer, restored from ``checkpoint["optimizer"]``.
        scheduler: Optional LR scheduler, restored from
            ``checkpoint["lr_scheduler"]``.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        The value stored under ``checkpoint["epoch"]``.

    Raises:
        KeyError: If a required key is missing from the checkpoint.
        RuntimeError: If the state dict does not match ``model`` even after
            the ``"module."`` prefix is stripped.
    """
    # NOTE(review): depending on the torch version's weights_only default,
    # torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(ckp_path, map_location=device)
    start_epoch = checkpoint["epoch"]
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict reports key/shape mismatches as RuntimeError. Other
        # exceptions (e.g. KeyboardInterrupt) must propagate instead of
        # triggering the fallback.
        prefix = "module."
        # Strip only a leading prefix: str.replace would also rewrite inner
        # names such as "submodule." -> "sub".
        stripped = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(stripped)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint["lr_scheduler"])
    return start_epoch
def get_remaining_time(iter, epoch, epoch_iters, end, args):
    """Estimate the remaining training time as an ``H:MM:SS`` string.

    The average time per iteration is measured since
    ``schedule_config["start_time"]`` and ``schedule_config["start_epoch"]``.
    That average is multiplied by the number of iterations still left, which
    is ``train_iters - curr_iter``.

    Args:
        iter: Zero-based iteration index within the current epoch.
        epoch: Current epoch number.
        epoch_iters: Number of iterations per epoch.
        end: Current timestamp, in the same units as ``start_time``.
        args: Object whose ``schedule_config`` dict holds ``start_epoch``,
            ``train_iters``, ``curr_iter`` and ``start_time``.

    Returns:
        ``str(datetime.timedelta(...))`` of the estimate, truncated to whole
        seconds.
    """
    cfg = args.schedule_config
    iters_done = 1 + iter + epoch_iters * (epoch - cfg["start_epoch"])
    iters_left = cfg["train_iters"] - cfg["curr_iter"]
    sec_per_iter = (end - cfg["start_time"]) / iters_done
    return str(datetime.timedelta(seconds=int(iters_left * sec_per_iter)))