-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathk_maxpooling.py
27 lines (22 loc) · 983 Bytes
/
k_maxpooling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from keras.engine import Layer, InputSpec
from keras.layers import Flatten
import tensorflow as tf
class KMaxPooling(Layer):
    """
    K-max pooling layer that extracts the k-highest activations from a sequence (2nd dimension).
    TensorFlow backend.

    The layer accepts a 3D tensor and, independently for every position of the
    last dimension, keeps the ``k`` largest values found along the 2nd
    dimension. The 1st and 3rd dimensions are left unchanged, so the output
    shape is ``(input_shape[0], k, input_shape[2])``.

    Arguments:
        k: number of top activations to keep along the 2nd dimension.
        sorted: forwarded unchanged to ``tf.nn.top_k`` as its ``sorted``
            argument (controls whether the selected values come back ordered).
            The name shadows the builtin but is kept for backward
            compatibility with existing callers.
        **kwargs: standard keyword arguments passed to the base ``Layer``.
    """

    def __init__(self, k=1, sorted=True, **kwargs):
        super().__init__(**kwargs)
        # Only rank-3 inputs are accepted.
        self.input_spec = InputSpec(ndim=3)
        self.k = k
        self.sorted = sorted

    def compute_output_shape(self, input_shape):
        """Return the output shape: the 2nd dimension is replaced by ``k``."""
        return (input_shape[0], self.k, input_shape[2])

    def call(self, inputs):
        """Select the ``k`` largest values along the 2nd dimension of ``inputs``."""
        # swap last two dimensions since top_k will be applied along the last dimension
        shifted_inputs = tf.transpose(inputs, [0, 2, 1])

        # top_k returns two tensors [values, indices]; only the values are needed
        top_k = tf.nn.top_k(shifted_inputs, k=self.k, sorted=self.sorted)[0]

        # swap the last two dimensions back so the output keeps the input's axis order
        return tf.transpose(top_k, [0, 2, 1])

    def get_config(self):
        """Return the layer configuration, including ``k`` and ``sorted``.

        Without this override the constructor arguments are not part of the
        serialized config, so a saved model could not re-create the layer with
        the same settings. The keys match the ``__init__`` parameter names.
        """
        config = dict(super().get_config())
        config['k'] = self.k
        config['sorted'] = self.sorted
        return config