-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpredict.py
96 lines (69 loc) · 2.92 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from utils import *
from dataloader import *
from transformer import *
from search import *
def load_model(ckpt_path = None, src_char_vocab = None, src_word_vocab = None, tgt_word_vocab = None):
    """Load a trained transformer and its vocabularies.

    All paths default to the command-line arguments, preserving the original
    zero-argument call, but may be passed explicitly so the function can be
    used programmatically without sys.argv.

    Returns: (model, x_cti, x_wti, y_itw) where x_cti / x_wti map source
    characters / words to indices and y_itw maps target indices back to words.
    """
    # resolve defaults lazily so sys.argv is read at call time, not def time
    ckpt_path = sys.argv[1] if ckpt_path is None else ckpt_path
    src_char_vocab = sys.argv[2] if src_char_vocab is None else src_char_vocab
    src_word_vocab = sys.argv[3] if src_word_vocab is None else src_word_vocab
    tgt_word_vocab = sys.argv[4] if tgt_word_vocab is None else tgt_word_vocab
    x_cti = load_tkn_to_idx(src_char_vocab) # source char -> index
    x_wti = load_tkn_to_idx(src_word_vocab) # source word -> index
    y_wti = load_tkn_to_idx(tgt_word_vocab) # target word -> index (model output layer)
    y_itw = load_idx_to_tkn(tgt_word_vocab) # target index -> word (for detokenizing)
    model = transformer(x_cti, x_wti, y_wti)
    print(model)
    load_checkpoint(ckpt_path, model)
    return model, x_cti, x_wti, y_itw
def run_model(model, data, y_itw):
    # Decode every batch with the trained model and yield one
    # (source_string, gold_reference, hypothesis_tokens) triple per input.
    # y_itw: target index-to-word map used to detokenize predictions.
    with torch.no_grad():
        model.eval()
        # each beam candidate occupies its own row, hence BATCH_SIZE * BEAM_SIZE
        for batch in data.batchify(BATCH_SIZE * BEAM_SIZE):
            xc, xw, _, lens = batch.sort()
            xc, xw = data.to_tensor(xc, xw, lens, eos = True)
            eos = [False for _ in xw] # EOS states
            b, t = len(xw), 0 # rows in batch, current decode timestep
            x_mask = padding_mask(xw, xw)
            batch.y1 = [[] for _ in xw] # hypothesis token indices per row
            batch.prob = [0 for _ in xw] # accumulated hypothesis scores per row
            # header rows (source tokens) for the attention / copy tables
            batch.attn = [[["", *batch.x1[i], EOS]] for i in batch.idx]
            batch.copy = [[["", *batch.x1[i]]] for i in batch.idx]
            # cache encoder output on the decoder so each step can attend to it
            model.dec.M = model.enc(xc, xw, x_mask)
            yi = LongTensor([[SOS_IDX]] * b) # decoder input starts with SOS
            # autoregressive loop: stop at MAX_LEN or when every row hit EOS
            while t < MAX_LEN and sum(eos) < len(eos):
                y_mask = padding_mask(yi, yi) | lookahead_mask(yi, yi)
                xy_mask = padding_mask(yi, xw)
                yo = model.dec(yi, y_mask, xy_mask)[:, -1] # scores for last position only
                args = (model.dec, batch, y_itw, eos, lens, yo)
                yo = beam_search(*args, t) if BEAM_SIZE > 1 else greedy_search(*args)
                yi = torch.cat([yi, yo], 1) # append chosen tokens for the next step
                t += 1
            batch.unsort()
            if VERBOSE:
                for i in range(0, len(batch.y1), BEAM_SIZE):
                    # NOTE(review): remaps the beam-group start (0, B, 2B, ...) to a
                    # dense index (0, 1, 2, ...) for both label and attn lookup —
                    # verify batch.attn is indexed per group, not per row
                    i //= BEAM_SIZE
                    print("attn[%d] =" % i)
                    print(mat2csv(batch.attn[i]), end = "\n\n")
            for i, (x0, y0, y1) in enumerate(zip(batch.x0, batch.y0, batch.y1)):
                if not i % BEAM_SIZE: # use the best candidate from each beam
                    y1 = [y_itw[y] for y in y1[:-1]] # drop trailing EOS, detokenize
                    yield x0, y0, y1
def predict(model, x_cti, x_wti, y_itw, filename):
    """Read test data from *filename*, index it, and return the generator
    produced by run_model over (source, gold, hypothesis) triples.

    Each input line is either "source" or "source<TAB>gold_reference".
    x_cti / x_wti: source char / word vocabularies; y_itw: target
    index-to-word map forwarded to run_model.
    """
    data = dataloader(batch_first = True)
    # fix: use a context manager so the file is closed even if tokenization
    # or indexing raises (original open()/close() leaked on error)
    with open(filename) as fo:
        for line in fo:
            data.append_row()
            x0, y0 = line.strip(), []
            if x0.count("\t") == 1: # optional gold reference after a tab
                x0, y0 = x0.split("\t")
            x1 = tokenize(x0, UNIT)
            # unknown chars/words fall back to UNK_IDX
            xc = [[x_cti.get(c, UNK_IDX) for c in w] for w in x1]
            xw = [x_wti.get(w, UNK_IDX) for w in x1]
            data.append_item(x0 = x0, x1 = x1, xc = xc, xw = xw, y0 = y0)
            # replicate the row so each beam candidate gets its own slot
            for _ in range(BEAM_SIZE - 1):
                data.clone_row()
    return run_model(model, data, y_itw)
if __name__ == "__main__":
    # expects: model checkpoint, three vocabulary files, test data file
    if len(sys.argv) != 6:
        sys.exit("Usage: %s model vocab.src.char_to_idx vocab.src.word_to_idx vocab.tgt.word_to_idx test_data" % sys.argv[0])
    # stream predictions; show the gold reference (when present) before the hypothesis
    for src, gold, hyp in predict(*load_model(), sys.argv[5]):
        if gold:
            print((src, gold))
        print((src, hyp))