forked from haiderstats/ISDEvaluation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanalysisMaster.R
374 lines (346 loc) · 20.2 KB
/
analysisMaster.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
#### File Information #####################################################################################################################
#File Name: analysisMaster.R
#Date Created: May 26, 2018
#Author: Humza Haider
#Email: [email protected]
### General Comments ######################################################################################################################
#This file can act as a master file to analyze a given dataset with all modeling techniques and evaluation metrics.
### Functions #############################################################################################################################
## Function 1: analysisMaster = function(survivalDataset, numberOfFolds =5,
# CoxKP = T,CoxKPEN = T, KaplanMeier = T, RSFModel = T, AFTModel = T, MTLRModel =T, #Models
# DCal = T, OneCal = T, Concor = T, L1Measure = T, BrierInt = T, BrierSingle = T, #Evaluations
# DCalBins = 10, OneCalTime = NULL, concordanceTies = "Risk", #Evaluation args
# SingleBrierTime = NULL, IntegratedBrierTimes = NULL, numBrierPoints = 1000, Ltype = "Margin",
# Llog = F, typeOneCal = "DN", oneCalBuckets = 10, survivalPredictionMethod = "Median",
# AFTDistribution = "weibull", #Model args,
# FS = T, imputeZero=T, verbose = T # Misc args)
#Inputs:
# survivalDataset - This is the dataset one wishes to analyze. This must include 'time', 'delta', and at least 1 more feature. No default.
# numberOfFolds - The number of desired cross-validation folds. Default is 5.
# CoxKP, CoxKPEN, KaplanMeier, RSFModel, AFTModel, MTLRModel: Booleans specifying whether or not to run that model. Default is TRUE.
# DCal, OneCal, Concor, L1Measure, BrierSingle, BrierInt: Booleans specifying whether or not to run that evaluation metric. Default is TRUE.
# DCalBins: Number of bins for D-Calibration. Default is 10.
# OneCalTime: An int specifying the time to evaluate 1-Calibration. If left as NULL but OneCal = TRUE, then the 10th, 25th, 50th, 75th,
# and 90th percentiles of all event times are used. Default is NULL.
# concordanceTies: A string ("None", "Time", "Risk","All") indicating how to handle ties in concordance. Default is "Risk".
# SingleBrierTime: The time to evaluate the Brier Score. If left as null, the 50th percentile of all event times is used. Default is NULL.
# IntegratedBrierTimes: A 2 length vector (e.g. c(0,100)) specifying the lower and upper bounds on the integrated Brier score. If NULL then
# the default is 0 as a lower bound and the max event time of the entire dataset is used as an upper bound. Default is NULL.
# numBrierPoints: The number of points to evaluate the integrated Brier score. A simple trapezoidal numerical approximation is used. Default
# is 1000 points.
# Ltype: The type of L1-loss. Must be one of "Uncensored","Hinge", or "Margin". Default is "Margin".
# Llog: A boolean specifying whether or not to use log-L1 metric. Default is FALSE.
# typeOneCal: A string indicating the type of 1-Calibration to use. Must be one of "DN" or "Uncensored". Default is "DN".
# oneCalBuckets: An int specifying number of bins for 1-Calibration. Default is 10.
# survivalPredictionMethod: The way in which to estimate average survival times. Must be one of "Mean" or "Median". Default is "Median".
# AFTDistribution: The distribution to use for AFT, default is "weibull". Must be one of "weibull","exponential","lognormal","gaussian",
# "loglogistic","logistic".
# FS: A boolean specifying whether or not to use feature selection. Default is TRUE.
# imputeZero: A boolean specifying whether 0 valued times should be imputed (AFT breaks for 0 valued times). If TRUE then 0 valued times are
# imputed to half the minimum non-zero time. Default is TRUE.
# verbose: A boolean specifying whether or not to return results and progress information.
#Output: A list of (3) items:
#(1) datasetUsed: This is the dataset that is actually used post feature selection but pre normalization and imputation. datasetUsed
#will have all the patients who had acceptable time and delta values and the features that were selected.
#(2) survivalCurves: This is a list containing the survival curves for all patients for each model that was tested.
#(3) results: This is a dataframe containing all the evaluation results with specified model and fold number. Additionally the sample size
#feature size, and censoring percentage are returned. Notice that the feature sizes before and after one hot encoding are returned.
#If none of the features were factors then NumFeatures should equal NumFeaturesOneHot.
#Note that survivalCurves can be plotted by plotSurvivalCurves().
## Function 2: getSurvivalCurves()
# coxTimes, coxENTimes, kmTimes, aftTimes, rsfTimes, mtlrTimes - The times used for prediction of each model.
# CoxKP = T,CoxKPEN=T, KaplanMeier = T, RSFModel = T, AFTModel = T, MTLRModel =T: The models used in analysisMaster.
# combinedTestResults: A List containing all model survival curves.
# numberOfFolds: Number of folds for cross validation.
# originalIndexing: The original indexing prior to CV folds.
#Output: The survival curves of all survival models for all test patients.
#Usage: This is a helper function for analysisMaster(). This is used to get the survival curves for each model for each patient.
### Code ##################################################################################################################################
#Data processing files:
source("ValidateCleanCV/validateAndClean.R")
source("ValidateCleanCV/createFoldsAndNormalize.R")
#Modeling files:
source("Models/CoxPH_KP.R")
source("Models/KaplanMeier.R")
source("Models/RandomSurvivalForests.R")
source("Models/AcceleratedFailureTime.R")
source("Models/MTLR.R")
#Evaluation files:
source("Evaluations/DCalibration.R")
source("Evaluations/OneCalibration.R")
source("Evaluations/Concordance.R")
source("Evaluations/L1Measures.R")
source("Evaluations/BrierScore.R")
#Misc files:
source("FeatureSelection/FeatureSelection.R")
source("Plotting/plotSurvivalCurves.R")
#Master driver: runs every selected survival model under k-fold cross validation and
#scores each one with every selected evaluation metric. See the file header comments
#above for the full description of each argument and of the returned list
#(datasetUsed, survivalCurves, results).
analysisMaster = function(survivalDataset, numberOfFolds =5,
                          CoxKP = T,CoxKPEN = T, KaplanMeier = T, RSFModel = T, AFTModel = T, MTLRModel =T, #Models
                          DCal = T, OneCal = T, Concor = T, L1Measure = T, BrierInt = T, BrierSingle = T, #Evaluations
                          DCalBins = 10, OneCalTime = NULL, concordanceTies = "Risk", #Evaluation args
                          SingleBrierTime = NULL, IntegratedBrierTimes = NULL, numBrierPoints = 1000, Ltype = "Margin", #Evaluation args
                          Llog = F, typeOneCal = "DN", oneCalBuckets = 10, survivalPredictionMethod = "Median", #Evaluation args
                          AFTDistribution = "weibull", #Model args,
                          FS = T, imputeZero=T, verbose = T # Misc args
){
  #Drop/repair rows with unusable time or delta values (zero times optionally imputed
  #to half the minimum non-zero time, since AFT cannot handle time = 0).
  validatedData = validateAndClean(survivalDataset, imputeZero)
  #Optional feature selection (univariate Cox screening) BEFORE normalization/folding.
  if(FS)
    validatedData = FeatureSelection(validatedData, type = "UniCox")
  foldsAndNormalizedData = createFoldsAndNormalize(validatedData, numberOfFolds)
  #originalIndexing: row order prior to fold assignment; used later to restore the
  #patients' original order when stitching fold-wise curves back together.
  originalIndexing = foldsAndNormalizedData[[1]]
  #normalizedData[[1]] holds the training sets, [[2]] the test sets, one per fold.
  normalizedData = foldsAndNormalizedData[[2]]
  evaluationResults = data.frame()
  #Per-model containers accumulating the survival curves of every fold's test patients.
  #NOTE: the list order (Cox, CoxEN, KM, AFT, RSF, MTLR) must match getSurvivalCurves().
  combinedTestResults = list(Cox = list(),CoxEN = list(), KM = list(), AFT = list(), RSF = list(), MTLR = list())
  #Accumulated prediction time points per model across all folds (needed so curves from
  #different folds can be evaluated on a common time grid in getSurvivalCurves()).
  coxTimes = NULL;coxENTimes = NULL; kmTimes = NULL; rsfTimes = NULL; aftTimes = NULL; mtlrTimes = NULL;
  for(i in 1:numberOfFolds){
    if(verbose){
      print(Sys.time())
      print(paste("Starting fold",i,"of", numberOfFolds, "total folds."))
    }
    #Models - We evaluate values to NULL so we can pass them to evaluations, regardless if the models were ran or not.
    coxMod = NULL;coxENMod =NULL; kmMod = NULL; rsfMod = NULL; aftMod = NULL; mtlrMod = NULL;
    training = normalizedData[[1]][[i]]
    testing = normalizedData[[2]][[i]]
    if(verbose){
      print(paste("Beginning model training."))
    }
    if(CoxKP){
      if(verbose){
        print("Starting Cox Proportional Hazards.")
      }
      coxMod = CoxPH_KP(training, testing)
      #A length-1 return presumably signals a model failure (e.g. non-convergence) —
      #TODO confirm against CoxPH_KP. On failure the model is disabled for ALL remaining
      #folds and any results it produced on earlier folds are discarded.
      if(length(coxMod) ==1){
        combinedTestResults$Cox = list()
        coxTimes = NULL
        CoxKP = F
        if(i > 1)
          evaluationResults = with(evaluationResults,evaluationResults[-which(Model == "CoxKP"),])
      }
      else{
        combinedTestResults$Cox[[i]] = coxMod
        #coxMod[[1]]$time: the time points at which this fold's curves are defined.
        coxTimes = c(coxTimes,coxMod[[1]]$time)
      }
    }
    if(CoxKPEN){
      if(verbose){
        print("Starting Cox Proportional Hazards - Elastic Net.")
      }
      coxENMod = CoxPH_KP(training, testing,ElasticNet = T)
      #NOTE(review): unlike CoxKP/AFT, the elastic-net variant has no failure guard here.
      combinedTestResults$CoxEN[[i]] = coxENMod
      coxENTimes = c(coxENTimes,coxENMod[[1]]$time)
    }
    if(KaplanMeier){
      if(verbose){
        print("Starting Kaplan Meier.")
      }
      kmMod = KM(training, testing)
      combinedTestResults$KM[[i]] = kmMod
      kmTimes = c(kmTimes,kmMod[[1]]$time)
    }
    if(RSFModel){
      if(verbose){
        print("Starting Random Survival Forests.")
      }
      rsfMod = RSF(training, testing)
      combinedTestResults$RSF[[i]] = rsfMod
      rsfTimes = c(rsfTimes,rsfMod[[1]]$time)
    }
    if(AFTModel){
      if(verbose){
        print("Starting Accelerated Failure Time.")
      }
      aftMod = AFT(training, testing, AFTDistribution)
      #Same failure convention as CoxKP: a length-1 return disables AFT for the rest of
      #the run and removes its earlier-fold rows from evaluationResults.
      if(length(aftMod)==1){
        combinedTestResults$AFT = list()
        aftTimes = NULL
        AFTModel = F
        if(i >1)
          evaluationResults = with(evaluationResults,evaluationResults[-which(Model == "AFT"),])
      }
      else{
        combinedTestResults$AFT[[i]] = aftMod
        aftTimes = c(aftTimes,aftMod[[1]]$time)
      }
    }
    if(MTLRModel){
      if(verbose){
        print("Starting Multi-task Logistic Regression (PSSP).")
      }
      mtlrMod = MTLR(training, testing)
      combinedTestResults$MTLR[[i]] = mtlrMod
      mtlrTimes = c(mtlrTimes,mtlrMod[[1]]$time)
    }
    #Evaluations - Note that if evaluations are passed a NULL value they return a NULL.
    #(L2Results is initialized but never populated below — apparently a leftover.)
    DCalResults = NULL;OneCalResults = NULL;ConcordanceResults = NULL;
    BrierResultsInt = NULL;BrierResultsSingle = NULL;L1Results = NULL; L2Results = NULL;
    if(Concor){
      if(verbose){
        print("Staring Evaluation: Concordance")
      }
      coxConc = Concordance(coxMod, concordanceTies,survivalPredictionMethod)
      coxENConc = Concordance(coxENMod, concordanceTies,survivalPredictionMethod)
      kmConc = Concordance(kmMod, concordanceTies,survivalPredictionMethod)
      rsfConc = Concordance(rsfMod, concordanceTies,survivalPredictionMethod)
      aftConc = Concordance(aftMod, concordanceTies,survivalPredictionMethod)
      mtlrConc = Concordance(mtlrMod, concordanceTies,survivalPredictionMethod)
      #rbind drops the NULL entries of disabled models, so rows line up with modelsRan below.
      ConcordanceResults = rbind(coxConc,coxENConc, kmConc, rsfConc, aftConc, mtlrConc)
    }
    if(BrierInt){
      if(verbose){
        print("Staring Evaluation: Brier Score- Integrated")
      }
      coxBrierInt = BrierScore(coxMod, type = "Integrated", numPoints = numBrierPoints, integratedBrierTimes = IntegratedBrierTimes)
      coxENBrierInt = BrierScore(coxENMod, type = "Integrated", numPoints = numBrierPoints, integratedBrierTimes = IntegratedBrierTimes)
      kmBrierInt = BrierScore(kmMod, type = "Integrated", numPoints = numBrierPoints, integratedBrierTimes = IntegratedBrierTimes)
      rsfBrierInt = BrierScore(rsfMod, type = "Integrated",numPoints = numBrierPoints, integratedBrierTimes = IntegratedBrierTimes)
      aftBrierInt = BrierScore(aftMod, type = "Integrated", numPoints = numBrierPoints, integratedBrierTimes = IntegratedBrierTimes)
      mtlrBrierInt = BrierScore(mtlrMod, type = "Integrated", numPoints = numBrierPoints, integratedBrierTimes = IntegratedBrierTimes)
      BrierResultsInt = rbind(coxBrierInt,coxENBrierInt, kmBrierInt, rsfBrierInt, aftBrierInt, mtlrBrierInt)
    }
    if(BrierSingle){
      if(verbose){
        print("Staring Evaluation: Brier Score - Single")
      }
      coxBrierSingle = BrierScore(coxMod, type = "Single", singleBrierTime =SingleBrierTime )
      coxENBrierSingle = BrierScore(coxENMod, type = "Single", singleBrierTime =SingleBrierTime )
      kmBrierSingle = BrierScore(kmMod, type = "Single", singleBrierTime =SingleBrierTime )
      rsfBrierSingle = BrierScore(rsfMod, type = "Single", singleBrierTime =SingleBrierTime )
      aftBrierSingle = BrierScore(aftMod, type = "Single", singleBrierTime =SingleBrierTime )
      mtlrBrierSingle = BrierScore(mtlrMod, type = "Single", singleBrierTime =SingleBrierTime )
      BrierResultsSingle = rbind(coxBrierSingle,coxENBrierSingle, kmBrierSingle, rsfBrierSingle, aftBrierSingle, mtlrBrierSingle)
    }
    if(L1Measure){
      if(verbose){
        print("Staring Evaluation: L1 Loss")
      }
      coxL1 = L1(coxMod, Ltype, Llog,survivalPredictionMethod)
      coxENL1 = L1(coxENMod, Ltype, Llog,survivalPredictionMethod)
      kmL1 = L1(kmMod, Ltype, Llog,survivalPredictionMethod)
      rsfL1 = L1(rsfMod, Ltype, Llog,survivalPredictionMethod)
      aftL1 = L1(aftMod, Ltype, Llog,survivalPredictionMethod)
      mtlrL1 = L1(mtlrMod, Ltype, Llog,survivalPredictionMethod)
      L1Results = rbind(coxL1,coxENL1,kmL1,rsfL1,aftL1,mtlrL1)
    }
    #Assemble this fold's rows: one row per model actually run, one column per metric run.
    toAdd = as.data.frame(cbind(ConcordanceResults,
                                BrierResultsInt, BrierResultsSingle,L1Results))
    metricsRan = c(Concor,BrierInt,BrierSingle, L1Measure)
    names(toAdd) = c("Concordance",
                     "BrierInt","BrierSingle", "L1Loss")[metricsRan]
    #Model flags may have been flipped to F above, so this reflects the models that survived.
    modelsRan = c(CoxKP,CoxKPEN, KaplanMeier, RSFModel, AFTModel, MTLRModel)
    models = c("CoxKP","CoxKPEN","Kaplan-Meier","RSF","AFT", "MTLR")[modelsRan]
    #("FoldNumer" is a long-standing column-name typo; kept for backward compatibility.)
    if(any(metricsRan)){
      toAdd = cbind.data.frame(Model = models,FoldNumer = i, toAdd)
    }else{
      toAdd = cbind.data.frame(Model = models,FoldNumer = i)
    }
    evaluationResults = rbind.data.frame(evaluationResults, toAdd)
    if(verbose){
      print(evaluationResults)
    }
  }
  #D-Calibration is computed once over the pooled (cumulative) test curves of all folds,
  #then the single per-model value is repeated across each model's fold rows.
  if(DCal){
    if(verbose){
      print("Staring Evaluation: Cumulative D-Calibration")
    }
    coxDcal = DCalibrationCumulative(combinedTestResults$Cox,DCalBins)
    coxENDcal = DCalibrationCumulative(combinedTestResults$CoxEN,DCalBins)
    kmDcal = DCalibrationCumulative(combinedTestResults$KM,DCalBins)
    rsfDcal = DCalibrationCumulative(combinedTestResults$RSF,DCalBins)
    aftDcal = DCalibrationCumulative(combinedTestResults$AFT,DCalBins)
    mtlrDcal = DCalibrationCumulative(combinedTestResults$MTLR,DCalBins)
    DCalResults = c(coxDcal,coxENDcal, kmDcal, rsfDcal, aftDcal, mtlrDcal)
    evaluationResults$DCalibration = rep(DCalResults, numberOfFolds)
  }
  #1-Calibration is likewise cumulative; with OneCalTime = NULL multiple evaluation times
  #are produced, hence one OneCalibration_<k> column per time.
  if(OneCal){
    if(verbose){
      print("Staring Evaluation: Cumulative One-Calibration")
    }
    cox1cal = OneCalibrationCumulative(combinedTestResults$Cox, OneCalTime, typeOneCal, oneCalBuckets)
    coxEN1cal = OneCalibrationCumulative(combinedTestResults$CoxEN, OneCalTime, typeOneCal, oneCalBuckets)
    km1cal = OneCalibrationCumulative(combinedTestResults$KM, OneCalTime, typeOneCal, oneCalBuckets)
    rsf1cal = OneCalibrationCumulative(combinedTestResults$RSF, OneCalTime, typeOneCal, oneCalBuckets)
    aft1cal = OneCalibrationCumulative(combinedTestResults$AFT, OneCalTime, typeOneCal, oneCalBuckets)
    mtlr1cal = OneCalibrationCumulative(combinedTestResults$MTLR, OneCalTime, typeOneCal, oneCalBuckets)
    numTimes = max(sapply(list(cox1cal,coxEN1cal, km1cal, rsf1cal,aft1cal, mtlr1cal),length))
    for(times in 1:numTimes){
      #Dynamic column names via assign/eval(parse(...)) — works, though evaluationResults[[varName]]
      #with a plain vector would be the more conventional construction.
      varName = paste("OneCalibration_",times, sep="")
      assign(varName,c(cox1cal[times],coxEN1cal[times], km1cal[times], rsf1cal[times],aft1cal[times], mtlr1cal[times]))
      evaluationResults[varName] = rep(eval(parse(text=varName)), numberOfFolds)
    }
    if(verbose){
      print(evaluationResults)
    }
  }
  #We will add some basic information about the dataset.
  evaluationResults$N = nrow(validatedData)
  #Note we subtract 2 to not count `time` and `delta`.
  #NOTE(review): `training` here is the LAST fold's (one-hot encoded) training set.
  evaluationResults$NumFeatures = ncol(training) - 2
  evaluationResults$PercentCensored = sum(!validatedData$delta)/nrow(validatedData)
  #Stitch per-fold curves onto a shared time grid, in the patients' original order.
  survivalCurves = getSurvivalCurves(coxTimes,coxENTimes, kmTimes, aftTimes, rsfTimes, mtlrTimes,
                                     CoxKP,CoxKPEN, KaplanMeier, RSFModel, AFTModel, MTLRModel,
                                     combinedTestResults, numberOfFolds,originalIndexing)
  #Name order (..., KM, AFT, RSF, ...) matches getSurvivalCurves' internal model order,
  #which differs from the argument order above (RSF before AFT).
  names(survivalCurves) = c("Cox","CoxEN","KM","AFT","RSF","MTLR")[c(CoxKP,CoxKPEN, KaplanMeier, AFTModel,RSFModel, MTLRModel)]
  rownames(evaluationResults) = NULL
  return(list(datasetUsed = validatedData, survivalCurves = survivalCurves, results = evaluationResults))
}
#This function combines survival curves across the folds into one dataframe (we must get predictions for all
#the times across all folds otherwise we cannot combine patients from different folds into a dataframe.)
#For each model it builds the union of all folds' time points, interpolates every patient's
#curve onto that grid with a monotone spline, reorders columns back to the original patient
#order, and returns one data frame per model (first column "time", one column per patient).
getSurvivalCurves = function(coxTimes,coxENTimes, kmTimes, aftTimes, rsfTimes, mtlrTimes,
                             CoxKP = T,CoxKPEN=T, KaplanMeier = T, RSFModel = T, AFTModel = T, MTLRModel =T,
                             combinedTestResults, numberOfFolds, originalIndexing){
  #Permutation that maps fold-concatenated patient columns back to pre-CV row order.
  originalIndexOrder = order(unname(unlist(originalIndexing)))
  #Deduplicate and sort each model's pooled time grid (NULL means the model was not run).
  if(!is.null(coxTimes))
    coxTimes = sort(unique(coxTimes))
  if(!is.null(coxENTimes))
    coxENTimes = sort(unique(coxENTimes))
  if(!is.null(kmTimes))
    kmTimes = sort(unique(kmTimes))
  if(!is.null(rsfTimes))
    rsfTimes = sort(unique(rsfTimes))
  if(!is.null(aftTimes))
    aftTimes = sort(unique(aftTimes))
  if(!is.null(mtlrTimes))
    mtlrTimes = sort(unique(mtlrTimes))
  #Order here (Cox, CoxEN, KM, AFT, RSF, MTLR) deliberately matches the element order of
  #combinedTestResults, so j indexes models, allTimes, AND combinedTestResults consistently.
  models = c(CoxKP,CoxKPEN, KaplanMeier, AFTModel,RSFModel,MTLRModel)
  allTimes = list(coxTimes,coxENTimes,kmTimes,aftTimes,rsfTimes,mtlrTimes)
  survivalCurves = list()
  count = 0
  for(j in which(models)){
    count =count+1
    fullCurves = data.frame(row.names = 1:length(allTimes[[j]]))
    for(i in 1:numberOfFolds){
      #Index method -> fold -> survival curves
      times = combinedTestResults[[j]][[i]][[1]]$time
      maxTime = max(times)
      #Drop the time column; remaining columns are one survival curve per test patient.
      curves = combinedTestResults[[j]][[i]][[1]][,-1]
      #Grid points other folds predicted at but this fold did not — these need interpolation.
      timesToEvaluate = setdiff(allTimes[[j]],times)
      #Here we are going to combine the times from all folds and fit a spline so all patients have predictions for all times
      #across all folds.
      fullCurves = cbind.data.frame(fullCurves,sapply(curves,
                                                      function(x){
                                                        #'hyman' gives a monotone interpolant — appropriate since survival
                                                        #curves are non-increasing (hyman requires monotone input data).
                                                        curveSpline = splinefun(times,x,method='hyman')
                                                        maxSpline = curveSpline(maxTime)
                                                        #Clamp extrapolation: beyond this fold's last time point the curve is
                                                        #held constant at the (non-negative) value at maxTime.
                                                        curveSplineConstant = function(time){
                                                          timeToEval = ifelse(time > maxTime, maxTime,time)
                                                          toReturn = rep(NA,length(time))
                                                          toReturn[timeToEval== maxTime] = max(maxSpline,0)
                                                          toReturn[timeToEval !=maxTime] = curveSpline(timeToEval[timeToEval!=maxTime])
                                                          return(toReturn)
                                                        }
                                                        extraPoints =curveSplineConstant(timesToEvaluate)
                                                        #Merge the fold's exact predictions with the interpolated fill-ins,
                                                        #placing each at its position in the pooled, sorted grid.
                                                        toReturn = rep(NA, length(allTimes[[j]]))
                                                        originalIndex = which(!allTimes[[j]] %in% timesToEvaluate)
                                                        newIndex = which(allTimes[[j]] %in% timesToEvaluate)
                                                        toReturn[originalIndex] = x
                                                        toReturn[newIndex] = extraPoints
                                                        return(toReturn)
                                                      }
      ))
    }
    #Columns were appended fold by fold; restore the original (pre-CV) patient order.
    fullCurves = fullCurves[originalIndexOrder]
    fullCurves = cbind.data.frame(allTimes[j], fullCurves)
    colnames(fullCurves) = c("time",1:(ncol(fullCurves)-1))
    survivalCurves[[count]] = fullCurves
  }
  return(survivalCurves)
}