forked from facebookresearch/InferSent
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata.py
executable file
·92 lines (73 loc) · 3.1 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import numpy as np
import torch
def get_batch(batch, word_vec, emb_dim=300):
    """Embed a batch of tokenized sentences into a zero-padded, time-major tensor.

    Args:
        batch: sequence of sentences, each a sequence of tokens. Every token
            must be a key of ``word_vec`` (a missing token raises KeyError).
            NOTE(review): callers appear to pass sentences sorted by
            decreasing length; this function does not check or rely on it.
        word_vec: mapping token -> vector broadcastable to shape (emb_dim,).
        emb_dim: embedding dimension (last axis of the output).

    Returns:
        (embed, lengths) where ``embed`` is a float32 torch tensor of shape
        (max_len, batch_size, emb_dim) -- time first, batch second -- with
        positions past each sentence's end left as zeros, and ``lengths`` is
        a numpy array holding the number of tokens of each sentence.
    """
    lengths = np.array([len(sent) for sent in batch])
    # ndarray.max() on an empty array raises ValueError; an empty batch
    # yields an empty (0, 0, emb_dim) tensor instead.
    max_len = int(lengths.max()) if lengths.size else 0
    embed = np.zeros((max_len, len(batch), emb_dim))
    for i, sent in enumerate(batch):
        for j, word in enumerate(sent):
            embed[j, i, :] = word_vec[word]
    return torch.from_numpy(embed).float(), lengths
def get_word_dict(sentences):
    """Collect the whitespace-separated vocabulary of ``sentences``.

    Returns a dict whose keys are every distinct token (in order of first
    appearance) followed by the special tokens '<s>', '</s>' and '<p>';
    every value is the empty string, so the dict is used as an ordered set.
    """
    tokens = (tok for sentence in sentences for tok in sentence.split())
    vocab = dict.fromkeys(tokens, '')
    # Specials are always present; if one already occurred in the text it
    # keeps its original position.
    for special in ('<s>', '</s>', '<p>'):
        vocab[special] = ''
    return vocab
def get_glove(word_dict, glove_path, encoding='utf-8'):
    """Load GloVe vectors for the words listed in ``word_dict``.

    Args:
        word_dict: container of words to keep; only membership is tested.
        glove_path: path to a GloVe-format text file, one entry per line:
            a word, a single space, then space-separated float components.
        encoding: text encoding of the file. Defaults to UTF-8 rather than
            the platform's locale encoding, so results don't depend on the OS.

    Returns:
        dict mapping each word of ``word_dict`` that appears in the file to a
        1-D float numpy array. Words absent from the file are omitted; if a
        word appears twice, the later line wins.
    """
    word_vec = {}
    with open(glove_path, encoding=encoding) as f:
        for line in f:
            word, sep, vec = line.partition(' ')
            # Blank or vector-less lines have no separator; skip them instead
            # of failing on tuple unpacking.
            if not sep:
                continue
            if word in word_dict:
                word_vec[word] = np.array(list(map(float, vec.split())))
    print('Found {0}(/{1}) words with glove vectors'.format(
        len(word_vec), len(word_dict)))
    return word_vec
def build_vocab(sentences, glove_path):
    """Return GloVe vectors for the vocabulary of ``sentences``.

    The vocabulary (tokens plus the '<s>', '</s>', '<p>' specials) is built by
    ``get_word_dict`` and then looked up in the file at ``glove_path`` via
    ``get_glove``. The printed "Vocab size" is the number of words that were
    actually found in the GloVe file, i.e. the size of the returned dict.
    """
    vectors = get_glove(get_word_dict(sentences), glove_path)
    print('Vocab size : {0}'.format(len(vectors)))
    return vectors
def _read_lines(path):
    """Return the lines of the UTF-8 text file ``path`` with trailing whitespace stripped."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.rstrip() for line in f]


def get_nli(data_path):
    """Load the train/dev/test splits of a line-aligned NLI dataset.

    For each split ``X`` in ('train', 'dev', 'test') three files are read from
    ``data_path``: ``s1.X`` and ``s2.X`` (first and second sentence of each
    pair, one per line) and ``labels.X`` (one of 'entailment', 'neutral',
    'contradiction' per line). A summary line is printed per split.

    Returns:
        (train, dev, test): each a dict with keys 's1' and 's2' (lists of
        str) and 'label' (numpy array with entailment=0, neutral=1,
        contradiction=2).

    Raises:
        ValueError: if the three files of a split have different line counts.
        KeyError: if a label line is not one of the three known labels.
        OSError: if a file cannot be opened.
    """
    dico_label = {'entailment': 0, 'neutral': 1, 'contradiction': 2}
    splits = {}
    for data_type in ['train', 'dev', 'test']:
        s1 = _read_lines(os.path.join(data_path, 's1.' + data_type))
        s2 = _read_lines(os.path.join(data_path, 's2.' + data_type))
        raw_labels = _read_lines(os.path.join(data_path, 'labels.' + data_type))
        # Labels are stripped like the sentences, so trailing spaces don't
        # cause a spurious KeyError.
        label = np.array([dico_label[name] for name in raw_labels])
        # Explicit check instead of `assert`, which is removed under -O.
        if not len(s1) == len(s2) == len(label):
            raise ValueError(
                '{0} split is misaligned: {1} s1 / {2} s2 / {3} labels'.format(
                    data_type, len(s1), len(s2), len(label)))
        print('** {0} DATA : Found {1} pairs of {2} sentences.'.format(
            data_type.upper(), len(s1), data_type))
        splits[data_type] = {'s1': s1, 's2': s2, 'label': label}
    return splits['train'], splits['dev'], splits['test']