-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathevaluate_segments.py
66 lines (52 loc) · 2.9 KB
/
evaluate_segments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt
import scipy.stats as st
import seaborn as sns
import argparse
#parser = argparse.ArgumentParser()
#parser.add_argument("data_path")
#parser.add_argument("eval_path")
#args = parser.parse_args()
# Directory that receives the "metrics" report; created on first run.
result = Path("result")
result.mkdir(exist_ok=True)

# Model name templates: "%s" is substituted with the data split name to locate
# the alignment output below "eval/".  The assignments override one another, so
# only the LAST one is effective; the earlier lists are kept as a record of the
# other configurations that can be evaluated.
models = [
    "aeneas_%s",
    "maus_%s",
    "gentle_%s",
    "tedlium_%s_lstm",
    "tedlium_%s_transformer",
    "tedlium_%s_transformer_tedlium_16",
    "tedlium_%s_transformer_tedlium_8",
    "tedlium_%s_transformer_tedlium_4",
    "tedlium_%s_transformer_tedlium_2",
    "tedlium_%s_transformer_tedlium_1"
]
models = ["tedlium_%s_lstm","tedlium_%s_transformer_tedlium_2","tedlium_%s_transformer_libri_2","gentle_%s","aeneas_%s","maus_%s"]#["tedlium_%s_transformer_tedlium_2","gentle_%s"]
#models = ["tedlium_%s_lstm","gentle_%s"]#["tedlium_%s_transformer_tedlium_2","gentle_%s"]
models=["gentle_%s"]


def read_segments(path, text_start, skip_lines=0):
    """Parse a whitespace-separated segment file.

    Each kept line is split on whitespace; fields 0 and 1 are converted to
    float and the fields from index ``text_start`` onwards are re-joined with
    single spaces to form the segment text.

    Args:
        path: File to read.
        text_start: Index of the first field that belongs to the text.
        skip_lines: Number of leading lines to ignore (e.g. a header line).

    Returns:
        A list of ``(float, float, text)`` tuples, one per kept line.
        NOTE(review): the two floats are presumably start/end times in
        seconds (the report header uses "<1s") -- confirm against the data.
    """
    with open(str(path), "r") as f:
        rows = [line.split() for line in f.readlines()[skip_lines:]]
    return [(float(row[0]), float(row[1]), " ".join(row[text_start:])) for row in rows]


def boundary_diffs(segments, ground_truth):
    """Return the absolute boundary errors of ``segments`` w.r.t. ``ground_truth``.

    Segments are paired by position.  For every pair two values are produced:
    the absolute difference of the first boundary and of the second boundary.
    Ground-truth entries beyond ``len(segments)`` are not compared.

    Args:
        segments: ``(float, float, text)`` tuples produced by a model.
        ground_truth: ``(float, float, text)`` reference tuples.

    Returns:
        A 1-D numpy array with ``2 * len(segments)`` entries.

    Raises:
        IndexError: If there are fewer ground-truth entries than segments.
        Exception: If the texts of a positional pair differ.
    """
    # Fail with context instead of a bare "list index out of range".
    if len(ground_truth) < len(segments):
        raise IndexError(
            "Ground truth has only %d segments but the alignment has %d"
            % (len(ground_truth), len(segments))
        )
    diffs = []
    for segment, truth in zip(segments, ground_truth):
        if segment[2] != truth[2]:
            raise Exception("Text not matching: " + segment[2] + " <=> " + truth[2])
        diffs.append(abs(segment[0] - truth[0]))
        diffs.append(abs(segment[1] - truth[1]))
    return np.array(diffs)


def collect_diffs(model, data_pack, eval_root=Path("eval"), data_root=Path("data")):
    """Gather the boundary errors of one model over all splits of a data pack.

    For every ``*.wav`` file in ``data_root / "tedlium_<split>"`` the matching
    ``<name>.txt`` in ``eval_root / (model % split)`` (first line skipped, text
    from field 4) is compared with ``<name>_ground_truth.txt`` next to the wav
    file (no line skipped, text from field 2).

    Args:
        model: Model name template containing one ``%s`` for the split name.
        data_pack: Iterable of split names that are evaluated together.
        eval_root: Directory holding the per-model alignment outputs.
        data_root: Directory holding the per-split data folders.

    Returns:
        A 1-D numpy array of all boundary errors; empty if no wav file exists.
    """
    all_diffs = []
    for data in data_pack:
        eval_path = eval_root / (model % data)
        data_path = data_root / ("tedlium_%s" % data)
        for path_wav in data_path.glob("*.wav"):
            segments = read_segments(
                eval_path / path_wav.name.replace(".wav", ".txt"), 4, skip_lines=1
            )
            ground_truth = read_segments(
                data_path / path_wav.name.replace(".wav", "_ground_truth.txt"), 2
            )
            all_diffs.append(boundary_diffs(segments, ground_truth))
    # np.concatenate rejects an empty sequence, so handle "no data" explicitly.
    if not all_diffs:
        return np.array([])
    return np.concatenate(all_diffs, 0)


def format_metrics_row(label, model, diffs):
    """Format one report line: label, model, mean, median, std, <1 and <0.5 ratios.

    Args:
        label: Name written in the first column.
        model: Model name template written in the second column (unformatted).
        diffs: 1-D numpy array of boundary errors.

    Returns:
        The space-separated line terminated by a newline.  All five statistics
        are reported as "nan" when ``diffs`` is empty, so a missing split does
        not abort the evaluation of the remaining models.
    """
    if diffs.size == 0:
        stats = ["nan"] * 5
    else:
        stats = [
            "%.3f" % value
            for value in (
                diffs.mean(),
                np.median(diffs),
                diffs.std(),
                (diffs < 1).mean(),
                (diffs < 0.5).mean(),
            )
        ]
    return " ".join([label, model] + stats) + "\n"


with open(result / "metrics", "w") as o:
    o.write(" mean median std <1s <0.5s\n")
    # Each pack is evaluated jointly and reported under its first split name.
    for data_pack in [("test_augm", "dev_augm"), ("test", "dev")]:
        for model in models:
            o.write(format_metrics_row(data_pack[0], model, collect_diffs(model, data_pack)))