-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclass_DeepHit.py
204 lines (148 loc) · 9.8 KB
/
class_DeepHit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
'''
This declares the DeepHit architecture:

INPUTS:
    - input_dims: dictionary of dimension information
        > x_dim: dimension of features
        > num_Event: number of competing events (this does not include the censoring label)
        > num_Category: dimension of the time horizon of interest, i.e., |T| where T = {0, 1, ..., T_max-1}
            : this is equivalent to the output dimension
    - network_settings:
        > h_dim_shared & num_layers_shared: number of nodes and number of fully-connected layers for the shared subnetwork
        > h_dim_CS & num_layers_CS: number of nodes and number of fully-connected layers for the cause-specific subnetworks
        > active_fn: 'relu', 'elu', 'tanh'
        > initial_W: Xavier initialization is used as a baseline

LOSS FUNCTIONS:
    - 1. log-likelihood (this includes the log-likelihood of subjects who are censored)
    - 2. ranking loss (this is calculated only for acceptable pairs; see the paper for the definition)
    - 3. calibration loss (this is to reduce the calibration error; this is not included in the paper version)
'''
import numpy as np
import tensorflow as tf
import random
from tf_slim import fully_connected as FC_Net
import model_scripts.utils_network as utils
### user-defined functions
#import untils_network as utils
_EPSILON = 1e-08
##### USER-DEFINED FUNCTIONS
def log(x):
    """Return the element-wise natural log of ``x + _EPSILON``.

    The module-level ``_EPSILON`` offset keeps the result finite when ``x``
    contains zeros (e.g. masked probability sums in the DeepHit losses).

    Args:
        x: tensor (or tensor-convertible value) of non-negative values.

    Returns:
        ``tf.math.log(x + _EPSILON)``.
    """
    return tf.math.log(x + _EPSILON)
def div(x, y):
    """Return ``x / (y + _EPSILON)`` element-wise.

    The module-level ``_EPSILON`` offset avoids dividing by exactly zero.
    The division itself is performed with ``tf.compat.v1.div``.

    Args:
        x: numerator tensor.
        y: denominator tensor.

    Returns:
        ``tf.compat.v1.div(x, y + _EPSILON)``.
    """
    return tf.compat.v1.div(x, (y + _EPSILON))
class Model_DeepHit:
    """DeepHit survival model built as a ``tf.compat.v1`` (graph-mode) network.

    The constructor builds the whole graph inside
    ``tf.compat.v1.variable_scope(name)``:

    - placeholders for inputs, labels, times, masks, loss weights, dropout
      keep-probability and learning rate;
    - a shared fully-connected sub-network whose output is concatenated with
      the raw input (residual connection);
    - ``num_Event`` cause-specific sub-networks;
    - a softmax output layer over all ``num_Event * num_Category`` cells,
      reshaped to ``[N, num_Event, num_Category]`` (so each subject's joint
      distribution over (event, time bin) sums to 1);
    - three losses (log-likelihood, ranking, calibration), their weighted sum
      plus regularization, and an Adam training op.

    Args:
        sess: ``tf.compat.v1.Session`` used by ``get_cost``, ``train`` and
            ``predict``.
        name: variable-scope name for the model.
        input_dims: dict with keys ``'x_dim'``, ``'num_Event'`` and
            ``'num_Category'``.
        network_settings: dict with keys ``'h_dim_shared'``, ``'h_dim_CS'``,
            ``'num_layers_shared'``, ``'num_layers_CS'``, ``'active_fn'`` and
            ``'initial_W'`` (forwarded to ``utils.create_FCNet`` / ``FC_Net``).
    """

    def __init__(self, sess, name, input_dims, network_settings):
        self.sess = sess
        self.name = name

        # INPUT DIMENSIONS
        self.x_dim = input_dims['x_dim']
        self.num_Event = input_dims['num_Event']
        self.num_Category = input_dims['num_Category']

        # NETWORK HYPER-PARAMETERS
        self.h_dim_shared = network_settings['h_dim_shared']
        self.h_dim_CS = network_settings['h_dim_CS']
        self.num_layers_shared = network_settings['num_layers_shared']
        self.num_layers_CS = network_settings['num_layers_CS']
        self.active_fn = network_settings['active_fn']
        self.initial_W = network_settings['initial_W']

        # Regularization factors are passed positionally: the legacy ``l=``
        # keyword is not accepted by every Keras release.
        self.reg_W = tf.keras.regularizers.l2(0.5 * (1e-4))  # hidden layers
        self.reg_W_out = tf.keras.regularizers.l1(1e-4)  # output layer

        self._build_net()

    def _build_net(self):
        """Create placeholders, the network, the losses and the training op."""
        with tf.compat.v1.variable_scope(self.name):
            #### PLACEHOLDER DECLARATION
            # Fed by get_cost/train/predict; no op built here reads it.
            self.mb_size = tf.compat.v1.placeholder(tf.int32, [], name='batch_size')
            self.lr_rate = tf.compat.v1.placeholder(tf.float32, [], name='learning_rate')
            self.keep_prob = tf.compat.v1.placeholder(tf.float32, [], name='keep_probability')  # keeping rate
            self.a = tf.compat.v1.placeholder(tf.float32, [], name='alpha')  # weight of LOSS_1
            self.b = tf.compat.v1.placeholder(tf.float32, [], name='beta')   # weight of LOSS_2
            self.c = tf.compat.v1.placeholder(tf.float32, [], name='gamma')  # weight of LOSS_3

            self.x = tf.compat.v1.placeholder(tf.float32, shape=[None, self.x_dim], name='inputs')
            self.k = tf.compat.v1.placeholder(tf.float32, shape=[None, 1], name='labels')  # event/censoring label (censoring: 0)
            self.t = tf.compat.v1.placeholder(tf.float32, shape=[None, 1], name='timetoevents')

            self.fc_mask1 = tf.compat.v1.placeholder(tf.float32, shape=[None, self.num_Event, self.num_Category], name='mask1')  # for Loss 1
            self.fc_mask2 = tf.compat.v1.placeholder(tf.float32, shape=[None, self.num_Category], name='mask2')  # for Loss 2 / Loss 3

            ##### SHARED SUBNETWORK w/ FCNETS
            shared_out = utils.create_FCNet(
                self.x, self.num_layers_shared, self.h_dim_shared, self.active_fn,
                self.h_dim_shared, self.active_fn, self.initial_W, self.keep_prob, self.reg_W)
            last_x = self.x  # for residual connection
            h = tf.concat([last_x, shared_out], axis=1)

            # (num_layers_CS) layers for each of the num_Event cause-specific sub-networks
            out = []
            for _ in range(self.num_Event):
                cs_out = utils.create_FCNet(
                    h, self.num_layers_CS, self.h_dim_CS, self.active_fn,
                    self.h_dim_CS, self.active_fn, self.initial_W, self.keep_prob, self.reg_W)
                out.append(cs_out)
            # Stack per subject; the reshape below presumes each cs_out is [N, h_dim_CS].
            out = tf.stack(out, axis=1)
            out = tf.reshape(out, [-1, self.num_Event * self.h_dim_CS])
            out = tf.nn.dropout(out, rate=1 - self.keep_prob)

            # One softmax over all (event, time-bin) cells -> joint distribution per subject.
            out = FC_Net(out, self.num_Event * self.num_Category, activation_fn=tf.nn.softmax,
                         weights_initializer=self.initial_W, weights_regularizer=self.reg_W_out,
                         scope="Output")
            self.out = tf.reshape(out, [-1, self.num_Event, self.num_Category])

            ##### GET LOSS FUNCTIONS
            self.loss_Log_Likelihood()  # get loss1: Log-Likelihood loss
            self.loss_Ranking()         # get loss2: Ranking loss
            self.loss_Calibration()     # get loss3: Calibration loss

            # Weighted sum of the three losses plus the graph's regularization losses
            # (get_regularization_loss is called without a scope filter).
            self.LOSS_TOTAL = (self.a * self.LOSS_1 + self.b * self.LOSS_2 + self.c * self.LOSS_3
                               + tf.compat.v1.losses.get_regularization_loss())
            self.solver = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr_rate).minimize(self.LOSS_TOTAL)

    def _event_probs(self, e):
        """Return ``self.out[:, e, :]`` reshaped to ``[N, num_Category]``."""
        return tf.reshape(tf.slice(self.out, [0, e, 0], [-1, 1, -1]), [-1, self.num_Category])

    ### LOSS-FUNCTION 1 -- Log-likelihood loss
    def loss_Log_Likelihood(self):
        """Set ``self.LOSS_1``, the negative mean log-likelihood.

        For every subject the output is multiplied by ``fc_mask1`` and summed
        over events and time bins. Uncensored rows (``k != 0``) contribute
        ``log P(T=t, K=k | x)`` and censored rows (``k == 0``) contribute
        ``log sum P(T>t | x)``. Which cells each term selects is determined
        entirely by how the caller builds ``fc_mask1``.
        """
        I_1 = tf.sign(self.k)  # 1 for uncensored rows, 0 for censored (assumes k >= 0)

        # Both terms use the same masked probability mass; compute it once.
        masked_prob = tf.reduce_sum(
            input_tensor=tf.reduce_sum(input_tensor=self.fc_mask1 * self.out, axis=2),
            axis=1, keepdims=True)

        tmp1 = I_1 * log(masked_prob)         # uncensored: log P(T=t, K=k | x)
        tmp2 = (1. - I_1) * log(masked_prob)  # censored:   log sum P(T>t | x)

        self.LOSS_1 = - tf.reduce_mean(input_tensor=tmp1 + 1.0 * tmp2)

    ### LOSS-FUNCTION 2 -- Ranking loss
    def loss_Ranking(self):
        """Set ``self.LOSS_2``, the pairwise ranking loss.

        For each event ``e`` and each pair (i, j) with ``k_i == e + 1`` and
        ``t_i < t_j``, the penalty is ``exp(-(r_i(T_i) - r_j(T_i)) / sigma1)``.
        Here ``r_i(T_j)`` is subject i's event-``e`` probability mass inside
        row j of ``fc_mask2``.
        """
        sigma1 = tf.constant(0.1, dtype=tf.float32)

        # Loop-invariant pieces: a column of ones and the pairwise time ordering.
        one_vector = tf.ones_like(self.t, dtype=tf.float32)
        # T_{ij} = 1 if t_i < t_j and T_{ij} = 0 if t_i >= t_j
        T_pair = tf.nn.relu(tf.sign(
            tf.matmul(one_vector, tf.transpose(a=self.t)) - tf.matmul(self.t, tf.transpose(a=one_vector))))

        eta = []
        for e in range(self.num_Event):
            I_2 = tf.cast(tf.equal(self.k, e + 1), dtype=tf.float32)  # indicator for event
            # Squeeze only axis 1 so a mini-batch of size 1 stays rank 1 for tensor_diag.
            I_2 = tf.linalg.tensor_diag(tf.squeeze(I_2, axis=1))
            tmp_e = self._event_probs(e)  # event-specific joint prob.

            R = tf.matmul(tmp_e, tf.transpose(a=self.fc_mask2))  # no need to divide by each individual denominator
            # r_{ij} = risk of i-th subject based on j-th time condition, i.e. r_i(T_{j})
            diag_R = tf.reshape(tf.linalg.tensor_diag_part(R), [-1, 1])
            R = tf.matmul(one_vector, tf.transpose(a=diag_R)) - R  # R_{ij} = r_{j}(T_{j}) - r_{i}(T_{j})
            R = tf.transpose(a=R)  # now R_{ij} (i-th row, j-th column) = r_{i}(T_{i}) - r_{j}(T_{i})

            T = tf.matmul(I_2, T_pair)  # keep T_{ij} = 1 only when the event occurred for subject i

            tmp_eta = tf.reduce_mean(input_tensor=T * tf.exp(-R / sigma1), axis=1, keepdims=True)
            eta.append(tmp_eta)

        eta = tf.stack(eta, axis=1)  # stack referenced on subjects
        eta = tf.reduce_mean(input_tensor=tf.reshape(eta, [-1, self.num_Event]), axis=1, keepdims=True)
        self.LOSS_2 = tf.reduce_sum(input_tensor=eta)  # sum over subjects

    ### LOSS-FUNCTION 3 -- Calibration Loss
    def loss_Calibration(self):
        """Set ``self.LOSS_3``, a squared-error calibration loss.

        For each event ``e`` and subject i, compare ``r_i`` with the event
        indicator ``1[k_i == e + 1]`` via ``(r_i - 1[k_i == e + 1])**2``.
        ``r_i`` is the subject's event-``e`` probability mass inside its own
        ``fc_mask2`` row.
        """
        eta = []
        for e in range(self.num_Event):
            I_2 = tf.cast(tf.equal(self.k, e + 1), dtype=tf.float32)  # [N, 1] indicator for event
            tmp_e = self._event_probs(e)  # [N, num_Category] event-specific joint prob.

            # Sum over the time axis (not over subjects) so every subject's
            # prediction is compared with that same subject's indicator: [N, 1].
            r = tf.reduce_sum(input_tensor=tmp_e * self.fc_mask2, axis=1, keepdims=True)
            tmp_eta = tf.reduce_mean(input_tensor=(r - I_2) ** 2, axis=1, keepdims=True)
            eta.append(tmp_eta)

        eta = tf.stack(eta, axis=1)  # stack referenced on subjects
        eta = tf.reduce_mean(input_tensor=tf.reshape(eta, [-1, self.num_Event]), axis=1, keepdims=True)
        self.LOSS_3 = tf.reduce_sum(input_tensor=eta)  # sum over subjects

    def _make_feed_dict(self, DATA, MASK, PARAMETERS, keep_prob, lr_train):
        """Build the feed_dict shared by ``get_cost`` and ``train``."""
        (x_mb, k_mb, t_mb) = DATA
        (m1_mb, m2_mb) = MASK
        (alpha, beta, gamma) = PARAMETERS
        return {self.x: x_mb, self.k: k_mb, self.t: t_mb,
                self.fc_mask1: m1_mb, self.fc_mask2: m2_mb,
                self.a: alpha, self.b: beta, self.c: gamma,
                self.mb_size: np.shape(x_mb)[0],
                self.keep_prob: keep_prob, self.lr_rate: lr_train}

    def get_cost(self, DATA, MASK, PARAMETERS, keep_prob, lr_train):
        """Evaluate ``LOSS_TOTAL`` on a mini-batch without updating weights.

        Args:
            DATA: ``(x_mb, k_mb, t_mb)`` features, labels and times.
            MASK: ``(m1_mb, m2_mb)`` values for ``fc_mask1`` / ``fc_mask2``.
            PARAMETERS: ``(alpha, beta, gamma)`` loss weights.
            keep_prob: dropout keep probability.
            lr_train: learning rate (fed, but ``LOSS_TOTAL`` does not depend on it).

        Returns:
            The value of ``LOSS_TOTAL``.
        """
        return self.sess.run(self.LOSS_TOTAL,
                             feed_dict=self._make_feed_dict(DATA, MASK, PARAMETERS, keep_prob, lr_train))

    def train(self, DATA, MASK, PARAMETERS, keep_prob, lr_train):
        """Run one Adam update on a mini-batch.

        Arguments are the same as for ``get_cost``.

        Returns:
            ``sess.run([self.solver, self.LOSS_TOTAL], ...)``: a two-element
            list whose second element is ``LOSS_TOTAL`` from the same run.
        """
        return self.sess.run([self.solver, self.LOSS_TOTAL],
                             feed_dict=self._make_feed_dict(DATA, MASK, PARAMETERS, keep_prob, lr_train))

    def predict(self, x_test, keep_prob=1.0):
        """Return ``self.out`` for ``x_test``: shape ``[N, num_Event, num_Category]``."""
        return self.sess.run(self.out, feed_dict={self.x: x_test,
                                                  self.mb_size: np.shape(x_test)[0],
                                                  self.keep_prob: keep_prob})