forked from postech-db-lab-starlab/LSS-Similarity
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path nl_sql_dist.py
127 lines (101 loc) · 4.46 KB
/
nl_sql_dist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from xgboost import XGBClassifier
from model.neural_model import NeuralClassifier
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, roc_curve, auc
from sklearn.model_selection import train_test_split
import sys
import pickle
import json
import matplotlib.pyplot as plt
import argparse
def measure_only_true(true_label, pred, indent):
    """Print confusion-matrix counts plus precision and recall for binary labels.

    Class ``1`` is treated as the positive class and class ``0`` as the
    negative class; any other value is counted in none of the four cells.

    Args:
        true_label: ground-truth labels (array-like, e.g. numpy array or list).
        pred: predicted labels (array-like), same length as ``true_label``.
        indent: number of extra spaces prefixed to every printed line
            (``0`` reproduces the historical output exactly).

    Returns:
        dict with keys ``tp``, ``fp``, ``fn``, ``tn`` (ints) and ``precision``,
        ``recall`` (floats; ``nan`` when the denominator is zero).
    """
    # asarray on BOTH inputs: comparing a plain list to 1 yields a single
    # ``False`` rather than an element-wise mask, which silently zeroed counts.
    labels = np.asarray(true_label)
    preds = np.asarray(pred)

    tp = int(np.sum((preds == 1) & (labels == 1)))
    fp = int(np.sum((preds == 1) & (labels == 0)))
    fn = int(np.sum((preds == 0) & (labels == 1)))
    tn = int(np.sum((preds == 0) & (labels == 0)))

    # Explicit zero-denominator guard instead of relying on numpy to emit
    # nan with a RuntimeWarning.
    precision = tp / (tp + fp) if tp + fp else float('nan')
    recall = tp / (tp + fn) if tp + fn else float('nan')

    pad = ' ' * indent
    print(pad + " True Positive: %d" % tp)
    print(pad + " False Positive: %d" % fp)
    print(pad + " False Negative: %d" % fn)
    print(pad + " True Negative: %d" % tn)
    print(pad + "-----------------------------")
    print(pad + " Precision: %.2f" % precision)
    print(pad + " Recall: %.2f" % recall)
    print(pad + "=============================")

    return {
        'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn,
        'precision': precision, 'recall': recall,
    }
def run_xgb_model(model_params, train_feature, train_label, test_feature, save_model=False, pretrained_model=''):
    """Return an XGBoost classifier, either unpickled from disk or freshly trained.

    Args:
        model_params: keyword arguments forwarded to ``XGBClassifier``.
        train_feature: training feature matrix (used only when training).
        train_label: training labels (used only when training).
        test_feature: unused here; kept for interface compatibility.
        save_model: unused here; saving is done by the caller.
        pretrained_model: path to a pickled model. When non-empty the model
            is loaded from this file and no training happens.

    Returns:
        The loaded or trained model object.
    """
    if pretrained_model:
        # SECURITY: pickle.load can execute arbitrary code — only load model
        # files from a trusted source.
        with open(pretrained_model, 'rb') as fh:  # closes the handle (was leaked)
            return pickle.load(fh)

    model = XGBClassifier(**model_params)
    model.fit(train_feature, train_label)
    return model
def run_neural_model(model_params, train_feature, train_label, test_feature, save_model=False, pretrained_model=''):
    """Return a NeuralClassifier, either loaded from disk or freshly trained.

    Args:
        model_params: keyword arguments forwarded to ``NeuralClassifier``.
        train_feature: training feature matrix (used only when training).
        train_label: training labels (used only when training).
        test_feature: unused here; kept for interface compatibility.
        save_model: unused here; saving is done by the caller.
        pretrained_model: path to a file written with ``torch.save``. If it
            holds a dict (a state dict), a fresh ``NeuralClassifier`` is built
            from ``model_params`` and the weights are loaded into it; any other
            object is returned as the model itself.

    Returns:
        The loaded or trained model object.
    """
    if pretrained_model:
        # torch is not imported at the top of this file; import it here so
        # the XGB path does not require torch.
        import torch

        loaded = torch.load(pretrained_model)
        if isinstance(loaded, dict):
            # NOTE(review): assumes NeuralClassifier is a torch.nn.Module
            # (it exposes state_dict elsewhere in this file) — confirm.
            model = NeuralClassifier(**model_params)
            model.load_state_dict(loaded)
            return model
        return loaded

    model = NeuralClassifier(**model_params)
    model.trainer(train_feature, train_label)
    return model
def draw_roc_curve(pred, target):
    """Plot the ROC curve of scores ``pred`` against labels ``target`` and show it.

    The curve comes from sklearn's ``roc_curve``; its area (``auc``) is
    reported in the legend, and a dashed diagonal marks chance level.
    The window is displayed with ``plt.show()``.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(target, pred)
    area = auc(false_pos_rate, true_pos_rate)

    plt.figure()
    curve_label = 'ROC curve (area = %0.2f)' % area
    plt.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=2, label=curve_label)
    # Chance-level reference line.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')

    axis_settings = (
        (plt.xlim, [0.0, 1.0]),
        (plt.ylim, [0.0, 1.05]),
        (plt.xlabel, 'False Positive Rate'),
        (plt.ylabel, 'True Positive Rate'),
        (plt.title, 'Receiver operating characteristic example'),
    )
    for apply_setting, value in axis_settings:
        apply_setting(value)

    plt.legend(loc="lower right")
    plt.show()
def main(args):
    """Train (or load) a classifier on NL/SQL pair features and evaluate it.

    Reads the whitespace-separated feature file ``args.F`` (made by
    save_feature.py) and the JSON parameter file ``args.P``. Column layout, as
    used by the slicing below: columns 0-1 are the SQL id and NL id, columns
    2-4 are the model features, column 5 is the label.

    Without a pretrained model the rows are split into train/test sets using
    ``params['test_ratio']``. With a pretrained model every row is test data.
    The function then:
      * prints accuracy, precision and recall;
      * shows a ROC curve;
      * saves the test rows plus a prediction column via ``np.save``.

    Args:
        args: namespace with attributes ``F`` (feature file path) and ``P``
            (parameter JSON path).

    Raises:
        Exception: if ``params['model_type']`` is neither 'xgb' nor 'neural'.
    """
    # Load feature_label data made by save_feature.py
    feats = np.loadtxt(args.F)
    with open(args.P) as fh:
        params = json.load(fh)

    # Parameters
    model_params = params['model']
    pretrained_model = params['pretrained_model']

    features = feats[:, :5]
    labels = feats[:, 5]

    if pretrained_model != '':
        # A pretrained model is only evaluated, so use every row as test data.
        # (train_test_split(test_size=1) would mean ONE test sample, since an
        # int test_size is an absolute count.)
        train_feature_with_id = features[:0]
        train_label = labels[:0]
        test_feature_with_id, test_label = features, labels
    else:
        train_feature_with_id, test_feature_with_id, train_label, test_label = train_test_split(
            features, labels, test_size=params['test_ratio'], random_state=7)

    # Remove SQL id and NL id
    train_feature = train_feature_with_id[:, 2:5]
    test_feature = test_feature_with_id[:, 2:5]

    if params['model_type'] == 'xgb':
        # Pass pretrained_model by keyword: positionally it would bind to the
        # `save_model` parameter and the pretrained file would never be loaded.
        model = run_xgb_model(model_params,
                              train_feature, train_label, test_feature,
                              pretrained_model=pretrained_model)
        if params['save_model']:
            with open('data/saved_model/XGB_MODEL.dat', "wb") as fh:
                pickle.dump(model, fh)
    elif params['model_type'] == 'neural':
        model = run_neural_model(model_params,
                                 train_feature, train_label, test_feature,
                                 pretrained_model=pretrained_model)
        if params['save_model']:
            # torch is not imported at the top of this file.
            import torch
            # state_dict must be CALLED; saving the bound method is not the weights.
            # Loading this file requires rebuilding the model and calling
            # load_state_dict on it.
            torch.save(model.state_dict(), 'data/saved_model/NEURAL_MODEL.dat')
    else:
        raise Exception("Model should be 'xgb' or 'neural'")

    # Make prediction
    y_pred = model.predict(test_feature)
    predictions = [round(value) for value in y_pred]
    accuracy = accuracy_score(test_label, predictions)
    print("Accuracy: %.2f%%" % (accuracy * 100))
    measure_only_true(test_label, predictions, 0)

    # XGBClassifier.predict returns hard class labels, which give a degenerate
    # ROC curve; use the positive-class probability instead.
    if isinstance(model, XGBClassifier):
        roc_scores = np.asarray(model.predict_proba(test_feature))[:, 1]
    else:
        roc_scores = y_pred
    draw_roc_curve(roc_scores, test_label)

    # Append the predicted label as an extra column after the id/feature columns.
    test_feature_answer = np.concatenate(
        [test_feature_with_id, np.array(predictions).reshape(-1, 1)], -1)
    # NOTE: np.save appends '.npy', so this writes data/test_result.dat.npy.
    np.save('data/test_result.dat', test_feature_answer)
def build_arg_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser whose ``--features`` value is stored as ``F``
        and ``--parameters`` value as ``P``; both are optional with defaults.
    """
    parser = argparse.ArgumentParser()
    # %(default)s keeps the help text in sync with the real default value.
    parser.add_argument('--features', required=False, dest='F',
                        help='Path to final feature file. Default: "%(default)s"',
                        default='data/feature_answer_all.txt')
    parser.add_argument('--parameters', required=False, dest='P',
                        help='Path to parameter file. Default: "%(default)s"',
                        default='config/xgb_params.json')
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())