-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_without_pretraining.py
53 lines (38 loc) · 1.87 KB
/
test_without_pretraining.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import numpy as np
from dataloader import DataLoader
from time import time
from model import LiteralKG
import pandas as pd
from argument_test_without_pretraining import parse_args
from utils.log_utils import *
from utils.metric_utils import *
from utils.model_utils import *
def test_model(args):
    """Evaluate a LiteralKG checkpoint on the test split and persist the metrics.

    Loads the dataset, rebuilds the model, restores its weights from
    ``args.pretrain_model_path`` and runs ``evaluate`` on the test heads. Then:

    * writes Accuracy / Precision / Recall / F1 into ``args.evaluation_file``
      at row ``args.evaluation_row``;
    * saves a one-line summary to ``<save_dir>/test_results.tsv``;
    * saves the raw scores to ``<save_dir>/prediction_scores.npy``;
    * prints the summary.

    Args:
        args: Parsed command-line namespace (from ``parse_args``). Fields read
            here include ``device``, ``pretrain_model_path``,
            ``test_neg_rate``, ``evaluation_file``, ``evaluation_row`` and
            ``save_dir``. The whole namespace is also passed to
            ``DataLoader`` and ``LiteralKG``.
    """
    import os

    device = torch.device(args.device)

    # Load data.
    data = DataLoader(args, logging)
    # Drop any cached CUDA allocations before the model is built.
    torch.cuda.empty_cache()

    # Build the model and restore the checkpoint weights.
    model = LiteralKG(args, data.n_entities, data.n_relations, data.A_in,
                      data.num_embedding_table, data.text_embedding_table)
    model = load_model(model, args.pretrain_model_path)
    model.to(device)

    start = time()
    prediction_scores, metrics_dict = evaluate(
        model, data.test_head_dict, data.test_batch_size,
        data.prediction_tail_ids, device, neg_rate=args.test_neg_rate)
    metrics_str = ('Running test: Total Time {:.1f}s | Accuracy [{:.4f}], '
                   'Precision [{:.4f}], Recall [{:.4f}], F1 [{:.4f}]').format(
        time() - start, metrics_dict['accuracy'], metrics_dict['precision'],
        metrics_dict['recall'], metrics_dict['f1'])

    # Column name in the evaluation sheet -> key in metrics_dict.
    for column, key in (("Accuracy", "accuracy"), ("Precision", "precision"),
                        ("Recall", "recall"), ("F1", "f1")):
        update_evaluation_value(args.evaluation_file, column,
                                args.evaluation_row, metrics_dict[key])

    # os.path.join puts both files inside save_dir whether or not it has a
    # trailing slash; create the directory so the writes below cannot fail on
    # a missing path.
    os.makedirs(args.save_dir or '.', exist_ok=True)
    pd.DataFrame(data=[{"metrics": metrics_str}]).to_csv(
        os.path.join(args.save_dir, 'test_results.tsv'), sep='\t', index=False)
    np.save(os.path.join(args.save_dir, 'prediction_scores.npy'),
            prediction_scores)
    print(metrics_str)
def main():
    """Script entry point: parse the command line and run ``test_model``."""
    test_model(parse_args())


if __name__ == "__main__":
    main()