forked from THGLab/CSpred
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCSpred.py
executable file
·208 lines (185 loc) · 10.7 KB
/
CSpred.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python3
# Script for making predictions using both the X module and the Y module
# Author: Jie Li
# Date created: Sep 20, 2019
from spartap_features import PDB_SPARTAp_DataReader
from data_prep_functions import *
import ucbshifty
import joblib
import toolbox
import os
import pandas as pd
import argparse
import multiprocessing
# import warnings
import sys
# Refuse to run on interpreters older than Python 3.5
if sys.version_info < (3, 5):
    raise ValueError("Python >= 3.5 required")

# Turn off pandas' chained-assignment (Setting With Copy) warnings for the whole script
pd.options.mode.chained_assignment = None

# Folder holding this script (symbolic links resolved) and the default folder of the saved models
SCRIPT_PATH = os.path.split(os.path.realpath(__file__))[0]
ML_MODEL_PATH = "{}/models/".format(SCRIPT_PATH)
def build_input(pdb_file_name, pH=5, rcfeats=True, hse=True, hbrad=None):
    '''
    Function for building the feature dataframe for the specified pdb file.
    Returns the pandas dataframe produced by PDB_SPARTAp_DataReader.df_from_file_3res
    with an extra "pH" column holding the given pH value on every row.

    args:
        pdb_file_name = The path to the PDB file for prediction (str)
        pH = pH value to be considered
        rcfeats = Include a column for random coil chemical shifts (Bool)
        hse = Include a feature column for the half-sphere exposure (Bool)
        hbrad = Max length of hydrogen bonds for HA,HN and O (list of float).
                If None (default), [5.0, 5.0, 5.0] is used

    NOTE(review): the reader is called with first_chain_only=False, so the
    result is presumably not restricted to the first chain - confirm against
    the reader implementation.
    '''
    # Build the default per call instead of using a mutable default argument,
    # so the list handed to the reader is never shared between calls
    if hbrad is None:
        hbrad = [5.0] * 3

    # Build feature reader and read all features
    feature_reader = PDB_SPARTAp_DataReader()
    pdb_data = feature_reader.df_from_file_3res(
        pdb_file_name,
        rcshifts=rcfeats,
        hse=hse,
        first_chain_only=False,
        sequence_columns=0,
        hbrad=hbrad,
    )
    pdb_data["pH"] = pH
    return pdb_data
def data_preprocessing(data):
    '''
    Run every preprocessing step on the raw extracted features and return the
    resulting dataframe. The input dataframe itself is left untouched (a copy
    is processed).

    Steps, in order:
        1. order the columns alphabetically
        2. fix the HA2/HA3 ring current ambiguity (ha23ambigfix, mode=0)
        3. add residue specific features without one-hot columns
           (Add_res_spec_feats; its return value is not used, so it presumably
           works in place - confirm against data_prep_functions)
        4. add powered features: power 2 for hbondd_cols + cos_cols and
           powers -1, -2, -3 for hbondd_cols (feat_pwr)
        5. drop the unnecessary columns (DSSP, bookkeeping and random coil
           columns) that are actually present
    '''
    frame = data.copy()
    frame = frame[sorted(frame.columns)]
    frame = ha23ambigfix(frame, mode=0)
    Add_res_spec_feats(frame, include_onehot=False)
    frame = feat_pwr(frame, hbondd_cols + cos_cols, [2])
    frame = feat_pwr(frame, hbondd_cols, [-1, -2, -3])

    # Identification / bookkeeping columns that must not be used as features
    bookkeeping_cols = [
        'Unnamed: 0', 'Unnamed: 0.1', 'Unnamed: 0.1.1', 'FILE_ID',
        'PDB_FILE_NAME', 'RESNAME', 'RES_NUM', "RES", 'CHAIN', 'RESNAME_ip1',
        'RESNAME_im1', 'BMRB_RES_NUM', 'CG', 'RCI_S2', 'MATCHED_BMRB',
        "identifier",
    ]
    unwanted = set(dssp_pp_cols + dssp_energy_cols + bookkeeping_cols + rcoil_cols)
    # Only drop what exists, otherwise drop would complain about missing labels
    present_unwanted = unwanted & set(frame.columns)
    return frame.drop(present_unwanted, axis=1)
def prepare_data_for_atom(data,atom):
    '''
    Build the feature set for one atom type: all ring current columns except
    the one of the requested atom are removed, and missing values in the
    remaining ring current column are replaced by 0.

    args:
        data - the dataset that contains all the features (pandas.DataFrame)
        atom - the atom whose ring current column (atom + "_RC") is kept

    returns:
        pandas.DataFrame containing the cleaned feature set (the input
        dataframe is not modified)
    '''
    kept_ring_col = atom + '_RC'
    # All ring current columns but the one belonging to this atom
    other_ring_cols = ring_cols.copy()
    other_ring_cols.remove(kept_ring_col)

    atom_data = data.copy().drop(other_ring_cols, axis=1)
    atom_data[kept_ring_col] = atom_data[kept_ring_col].fillna(value=0)
    return atom_data
def calc_sing_pdb(pdb_file_name,pH=5,TP=True,TP_pred=None,ML=True,test=False):
    '''
    Function for calculating chemical shifts for a single PDB file using X module / Y module / both

    args:
        pdb_file_name = The path to the PDB file for prediction (str)
        pH = pH value to be considered
        TP = Whether or not use TP module (Bool)
        TP_pred = predicted shifts dataframe from Y module. If None, generate this prediction within this function (pandas.DataFrame / None)
        ML = Whether or not use ML module (Bool)
        test = Whether or not use test mode (Exclude mode for SHIFTY++, Bool)

    returns:
        pandas.DataFrame with one row per residue:
            - TP only:   RESNUM, RESNAME and <atom>_Y for every atom in toolbox.ATOMS
            - ML only:   RESNUM, RESNAME and <atom>_X
            - TP and ML: RESNUM, RESNAME, <atom>_X, <atom>_Y, <atom>_BEST_REF_SCORE,
                         <atom>_BEST_REF_COV, <atom>_BEST_REF_MATCH and <atom>_UCBShift
            - neither:   an empty dataframe

    raises:
        ValueError if the models directory (ML_MODEL_PATH) does not exist
    '''
    if not os.path.isdir(ML_MODEL_PATH):
        raise ValueError("models not installed in {}".format(ML_MODEL_PATH))
    if pH < 2 or pH > 12:
        print("Warning! Predictions for proteins in extreme pH conditions are likely to be erroneous. Take prediction results at your own risk!")

    preds = pd.DataFrame()
    if TP:
        if TP_pred is None:
            print("Calculating UCBShift-Y predictions ...")
            TP_pred = ucbshifty.main(pdb_file_name, 1, exclude=test)
        if not ML:
            # Prepare data when only doing TP prediction.
            # Explicit copy: new columns are added below and must not be
            # written into a slice of TP_pred
            preds = TP_pred[["RESNUM","RESNAME"]].copy()
            for atom in toolbox.ATOMS:
                # Add the <atom>_RC column back when the Y module provides one
                if atom+"_RC" in TP_pred.columns:
                    rc = TP_pred[atom+"_RC"]
                else:
                    rc = 0
                preds[atom+"_Y"] = TP_pred[atom] + rc

    if ML:
        print("Generating features ...")
        feats = build_input(pdb_file_name, pH)
        feats.rename(index=str, columns=sparta_rename_map, inplace=True) # Rename columns so that random coil columns can be correctly recognized
        # Keep the identification and random coil columns before preprocessing drops them
        resnames = feats["RESNAME"]
        resnums = feats["RES_NUM"]
        rcoils = feats[rcoil_cols]
        feats = data_preprocessing(feats)

        result = {"RESNUM":resnums, "RESNAME":resnames}
        for atom in toolbox.ATOMS:
            print("Calculating UCBShift-X predictions for %s ..." % atom)
            # Predictions for each atom
            atom_feats = prepare_data_for_atom(feats, atom)
            # os.path.join instead of string concatenation, so that a models
            # directory given without a trailing slash still resolves correctly
            r0 = joblib.load(os.path.join(ML_MODEL_PATH, "%s_R0.sav" % atom))
            r0_pred = r0.predict(atom_feats.values)

            # The R1 model takes the R0 prediction as an additional (last) feature
            feats_r1 = atom_feats.copy()
            feats_r1["R0_PRED"] = r0_pred
            r1 = joblib.load(os.path.join(ML_MODEL_PATH, "%s_R1.sav" % atom))
            r1_pred = r1.predict(feats_r1.values)
            # Write ML predictions
            result[atom+"_X"] = r1_pred + rcoils["RCOIL_"+atom]

            if TP:
                print("Calculating UCBShift predictions for %s ..." % atom)
                feats_r2 = atom_feats.copy()
                feats_r2["RESNAME"] = resnames
                feats_r2["RESNUM"] = resnums
                tp_atom = TP_pred[["RESNAME", "RESNUM", atom, atom+"_BEST_REF_SCORE", atom+"_BEST_REF_COV", atom+"_BEST_REF_MATCH"]]
                # Left merge keeps every residue of the feature table; the
                # residue name coming from the Y module becomes RESNAME_TP
                feats_r2 = pd.merge(feats_r2, tp_atom, on="RESNUM", suffixes=("","_TP"), how="left")
                # Write TP predictions
                result[atom+"_Y"] = feats_r2[atom].values
                result[atom+"_BEST_REF_SCORE"] = feats_r2[atom+"_BEST_REF_SCORE"].values
                result[atom+"_BEST_REF_COV"] = feats_r2[atom+"_BEST_REF_COV"].values
                result[atom+"_BEST_REF_MATCH"] = feats_r2[atom+"_BEST_REF_MATCH"].values
                # A residue is usable for R2 only when both modules agree on
                # the residue name and a TP shift exists
                valid = (feats_r2.RESNAME == feats_r2.RESNAME_TP) & (feats_r2[atom].notnull())
                # Subtract random coils to make secondary TP shifts
                feats_r2[atom] -= rcoils["RCOIL_"+atom].values
                feats_r2["R0_PRED"] = r0_pred
                valid_feats_r2 = feats_r2.drop(["RESNAME","RESNUM","RESNAME_TP"], axis=1)[valid]
                # Fall back to the R1 prediction wherever R2 cannot be applied
                r2_pred = r1_pred.copy()
                if len(valid_feats_r2):
                    r2 = joblib.load(os.path.join(ML_MODEL_PATH, "%s_R2.sav" % atom))
                    r2_pred_valid = r2.predict(valid_feats_r2.values)
                    r2_pred[valid] = r2_pred_valid
                # Write final predictions
                result[atom+"_UCBShift"] = r2_pred + rcoils["RCOIL_"+atom]
        preds = pd.DataFrame(result)
    return preds
def parse_batch_inputs(lines, default_pH):
    '''
    Parse the lines of a batch list file into [pdb_file_name, pH] pairs.

    args:
        lines = iterable of text lines; each line is a PDB file name, optionally followed by a pH value separated by whitespace
        default_pH = pH value used for lines that do not specify one

    returns:
        list of [pdb_file_name (str), pH] in file order. Blank lines are skipped.
    '''
    inputs = []
    for line in lines:
        tokens = line.split()
        if not tokens:
            # Blank line (e.g. a trailing newline at the end of the file)
            continue
        if len(tokens) == 1:
            # No pH values explicitly specified. Use the global pH values
            pH = default_pH
        else:
            pH = float(tokens[1])
        inputs.append([tokens[0], pH])
    return inputs


def batch_output_path(pdb_file_name, save_dir):
    '''
    Path of the csv file written for one PDB file in batch mode: the base name
    of the PDB file with its extension replaced by ".csv", inside save_dir
    (an empty save_dir means the current folder).

    The extension is always replaced, whatever it is, so the output can never
    end up with the same name as the input structure file.
    '''
    base_name = os.path.splitext(os.path.basename(pdb_file_name))[0]
    return os.path.join(save_dir, base_name + ".csv")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='This program is an NMR chemical shift predictor for protein chemical shifts (including H, Hα, C\', Cα, Cβ and N) in aqueous solution. It has two sub-modules: one is the machine learning (X) module. It uses ensemble tree-based methods to predict chemical shifts from the features extracted from the PDB files. The second sub-module is the transfer prediction (Y) module that predicts chemical shifts by "transfering" shifts from similar proteins in the database to the query protein through structure and sequence alignments. Finally, the two parts are combined to give the UCBShift predictions.')
    parser.add_argument("input", help="The query PDB file or list of PDB files for which the shifts are calculated")
    parser.add_argument("--batch", "-b", action="store_true", help="If toggled, input accepts a text file specifying all the PDB files need to be calculated (Each line is a PDB file name. If pH values are specified, followed with a space)")
    parser.add_argument("--output", "-o", help="Filename of generated output file. A file [shifts.csv] is generated by default. If in batch mode, you should specify the path for storing all the output files. Each output file has the same name as the input PDB file name.", default="shifts.csv")
    # NOTE: --worker is parsed but not read anywhere below; batch predictions run sequentially
    parser.add_argument("--worker", "-w", type=int, help="Number of CPU cores to use for parallel prediction in batch mode.", default=4)
    parser.add_argument("--shifty_only", "-y", "-Y", action="store_true", help="Only use UCBShift-Y (transfer prediction) module. Equivalent to executing UCBShift-Y directly with default settings")
    parser.add_argument("--shiftx_only", "-x", "-X", action="store_true", help="Only use UCBShift-X (machine learning) module. No alignment results will be utilized or calculated")
    parser.add_argument("--pH", "-pH", "-ph", type=float, help="pH value to be considered. Default is 5", default=5)
    parser.add_argument("--test", "-t", action="store_true", help="If toggled, use test mode for UCBShift-Y prediction")
    parser.add_argument("--models", help="Alternate location for models directory")
    args = parser.parse_args()

    if args.models:
        if not os.path.isdir(args.models):
            raise ValueError("Directory {} specified by models does not exists".format(args.models))
        # Joining with "" guarantees a trailing path separator, so model file
        # names appended to ML_MODEL_PATH always land inside the directory
        ML_MODEL_PATH = os.path.join(args.models, "")

    if not args.batch:
        preds = calc_sing_pdb(args.input, args.pH, TP=not args.shiftx_only, ML=not args.shifty_only, test=args.test)
        preds.to_csv(args.output, index=False)
    else:
        with open(args.input) as f:
            inputs = parse_batch_inputs(f, args.pH)

        # Decide saving folder
        if args.output == "shifts.csv":
            # No specific output path specified. Store all files in the current folder
            SAVE_PREFIX = ""
        else:
            SAVE_PREFIX = args.output
        if SAVE_PREFIX:
            # Make sure the requested output folder exists before writing into it
            os.makedirs(SAVE_PREFIX, exist_ok=True)

        for idx, item in enumerate(inputs):
            preds = calc_sing_pdb(item[0], item[1], TP=not args.shiftx_only, ML=not args.shifty_only, test=args.test)
            preds.to_csv(batch_output_path(item[0], SAVE_PREFIX), index=False)
            print("Finished prediction for %s (%d/%d)" % (item[0], idx + 1, len(inputs)))
    print("Complete!")