-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathevaluate.py
60 lines (41 loc) · 1.71 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from pathlib import Path
import argparse
import cv2
import numpy as np
from tqdm import tqdm
def general_dice(y_true, y_pred):
    """Return the Dice score of two masks, special-casing an empty ground truth.

    When ``y_true`` sums to zero the score is 1 if ``y_pred`` also sums to
    zero and 0 otherwise. In every other case the score is ``dice`` computed
    on the boolean masks of elements equal to 1.
    """
    if y_true.sum() != 0:
        # Ground truth has content: compare the "== 1" masks.
        return dice(y_true == 1, y_pred == 1)

    # Empty ground truth: perfect only if the prediction is empty as well.
    return 1 if y_pred.sum() == 0 else 0
def general_jaccard(y_true, y_pred):
    """Return the Jaccard index of two masks, special-casing an empty ground truth.

    When ``y_true`` sums to zero the score is 1 if ``y_pred`` also sums to
    zero and 0 otherwise. In every other case the score is ``jaccard``
    computed on the boolean masks of elements equal to 1.
    """
    if y_true.sum() != 0:
        # Ground truth has content: compare the "== 1" masks.
        return jaccard(y_true == 1, y_pred == 1)

    # Empty ground truth: perfect only if the prediction is empty as well.
    return 1 if y_pred.sum() == 0 else 0
def jaccard(y_true, y_pred):
    """Return the Jaccard index (intersection over union) of two masks.

    The intersection is the sum of the elementwise product; the union is the
    sum of both masks minus that intersection. A tiny epsilon is added to the
    numerator and the denominator, so two all-zero masks score 1.0 instead of
    dividing by zero.
    """
    eps = 1e-15
    overlap = (y_true * y_pred).sum()
    combined = y_true.sum() + y_pred.sum() - overlap
    return (overlap + eps) / (combined + eps)
def dice(y_true, y_pred):
    """Return the Dice coefficient of two masks.

    Computed as twice the sum of the elementwise product divided by the sum
    of both masks. A tiny epsilon is added to the numerator and the
    denominator, so two all-zero masks score 1.0 instead of dividing by zero.
    """
    eps = 1e-15
    overlap = (y_true * y_pred).sum()
    total = y_true.sum() + y_pred.sum()
    return (2 * overlap + eps) / (total + eps)
def _load_binary_mask(path):
    """Read the image at *path* as grayscale and binarize it into a 0/1 uint8 mask.

    Pixels strictly above half of 255 become 1, all others 0.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None``, which it does
            for files that are missing or cannot be decoded.
    """
    image = cv2.imread(str(path), 0)
    if image is None:
        # Without this check the comparison below would fail with an opaque
        # TypeError that does not say which file was the problem.
        raise FileNotFoundError('Could not read image: {}'.format(path))
    return (image > 255 * 0.5).astype(np.uint8)


if __name__ == '__main__':
    # Compare every ground-truth mask under --train_path with the prediction
    # of the matching name under --target_path and report the mean and
    # standard deviation of the Dice and Jaccard scores.
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--train_path', type=str, default='data/train/angyodysplasia/masks',
        help='path where train images with ground truth are located')
    arg('--target_path', type=str, default='predictions/UNet', help='path with predictions')
    args = parser.parse_args()

    result_dice = []
    result_jaccard = []

    # list() lets tqdm know the total and show a proper progress bar.
    for file_name in tqdm(list(Path(args.train_path).glob('*'))):
        y_true = _load_binary_mask(file_name)

        # Predictions are looked up as '<stem without "_a">.png'.
        # NOTE(review): presumably ground-truth files carry an '_a' marker in
        # their names; replace() drops every occurrence - confirm naming.
        pred_file_name = Path(args.target_path) / (file_name.stem.replace('_a', '') + '.png')
        y_pred = _load_binary_mask(pred_file_name)

        if y_true.shape != y_pred.shape:
            # Mismatched shapes would either raise deep inside NumPy or,
            # worse, broadcast silently and yield a meaningless score.
            raise ValueError('Shape mismatch between {} {} and {} {}'.format(
                file_name, y_true.shape, pred_file_name, y_pred.shape))

        result_dice.append(dice(y_true, y_pred))
        result_jaccard.append(jaccard(y_true, y_pred))

    if not result_dice:
        # np.mean / np.std of an empty list would print nan with warnings.
        raise SystemExit('No ground truth masks found in {}'.format(args.train_path))

    print('Dice = ', np.mean(result_dice), np.std(result_dice))
    print('Jaccard = ', np.mean(result_jaccard), np.std(result_jaccard))