forked from Gastron/sb-fin-parl-2015-2020-kevat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest-lfmmi-with-padding.py
78 lines (70 loc) · 2.99 KB
/
test-lfmmi-with-padding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python3
"""Finnish Parliament ASR"""
import os
import sys
import torch
import logging
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
import kaldi_io
import tqdm
from types import SimpleNamespace
def setup(hparams, run_opts):
""" Kind of mimics what Brain does """
if "device" in run_opts:
device = run_opts["device"]
elif "device" in hparams:
device = hparams["device"]
else:
device = "cpu"
print("Device is:", device)
if "cuda" in device:
torch.cuda.set_device(int(device[-1]))
modules = torch.nn.ModuleDict(hparams["modules"]).to(device)
hparams = SimpleNamespace(**hparams)
if hasattr(hparams, "checkpointer"):
if hasattr(hparams, "test_max_key"):
ckpt = hparams.checkpointer.find_checkpoint(max_key=hparams.test_max_key)
elif hasattr(hparams, "test_min_key"):
ckpt = hparams.checkpointer.find_checkpoint(min_key=hparams.test_min_key)
else:
ckpt = hparams.checkpointer.find_checkpoint()
hparams.checkpointer.load_checkpoint(ckpt)
epoch = hparams.epoch_counter.current
print("Loaded checkpoint from epoch", epoch, "at path", ckpt.path)
return modules, hparams, device
def count_scp_lines(scpfile):
    """Return the number of lines in the text file at ``scpfile``."""
    with open(scpfile) as scp:
        # Stream through the file; only the count is kept.
        return sum(1 for _entry in scp)
def run_test(modules, hparams, device):
    """Run the network over every test utterance and write the outputs.

    Each feature matrix listed in ``hparams.test_feats`` is normalized,
    padded at both ends by repeating its first and last frame
    ``hparams.contextlen`` times, encoded, trimmed back to the frames that
    belong to the utterance, passed through ``modules.lin_out`` and written
    under the utterance id to ``hparams.test_probs_out``.

    Arguments
    ---------
    modules : object
        Must provide ``normalize``, ``encoder`` and ``lin_out`` callables.
    hparams : object
        Must provide ``test_feats``, ``test_probs_out``, ``contextlen`` and
        ``subsampling``.
    device : str
        Device the features are moved to before the forward pass.
    """
    # Counted up front only so that the progress bar knows its total.
    num_utts = count_scp_lines(hparams.test_feats)
    with open(hparams.test_probs_out, 'wb') as fo:
        with torch.no_grad():
            for uttid, feats in tqdm.tqdm(kaldi_io.read_mat_scp(hparams.test_feats), total=num_utts):
                feats = torch.from_numpy(feats).to(device)
                # NOTE(review): lengths=[2.] and epoch=1000 are kept from
                # the original; presumably they make the normalizer cover
                # the full utterance without updating its statistics --
                # confirm against the normalizer in use.
                normalized = modules.normalize(feats.unsqueeze(0), lengths=torch.tensor([2.]), epoch=1000).squeeze(0)
                # Repeat the edge frames so the encoder sees context on
                # both sides of the real utterance.
                padded = torch.cat(
                    (
                        normalized[0].unsqueeze(0).repeat_interleave(hparams.contextlen, dim=0),
                        normalized,
                        normalized[-1].unsqueeze(0).repeat_interleave(hparams.contextlen, dim=0),
                    )
                )
                padded = padded.unsqueeze(0)
                # Encoder output positions that correspond to the unpadded
                # utterance (assumes the encoder shortens the time axis by
                # hparams.subsampling -- TODO confirm).
                first_relevant = int(hparams.contextlen / hparams.subsampling)
                last_relevant = first_relevant + int(feats.shape[0] / hparams.subsampling)
                encoded_all = modules.encoder(padded)
                encoded_relevant = encoded_all[:, first_relevant:last_relevant, :]
                # Use the trimmed slice: the original computed it but fed
                # encoded_all to lin_out, so padding frames were written.
                out = modules.lin_out(encoded_relevant)
                kaldi_io.write_mat(fo, out.squeeze(0).cpu().numpy(), key=uttid)
if __name__ == "__main__":
    # Split the command line into the hyperparameter file, the run options
    # and the hyperparameter overrides.
    cli_args = sys.argv[1:]
    hparams_file, run_opts, overrides = sb.parse_arguments(cli_args)

    # Parse the hyperparameter YAML, applying the command-line overrides.
    with open(hparams_file) as yaml_stream:
        loaded_hparams = load_hyperpyyaml(yaml_stream, overrides)

    # setup() returns exactly the (modules, hparams, device) triple that
    # run_test() takes.
    run_test(*setup(loaded_hparams, run_opts))