adding rank attr to gaussian noise layer to enable input shape determination for sup3r forward passes
bnb32 committed Sep 5, 2024
1 parent 3bdbabf commit 8aa1aa7
Showing 1 changed file with 34 additions and 28 deletions.
phygnn/layers/custom_layers.py: 62 changes (34 additions & 28 deletions)
@@ -177,12 +177,14 @@ def __init__(self, pool_size, strides=None, padding='valid', sigma=1,
def _make_2D_gaussian_kernel(edge_len, sigma=1.):
"""Creates 2D gaussian kernel with side length `edge_len` and a sigma
of `sigma`
Parameters
----------
edge_len : int
Edge size of the kernel
sigma : float
Sigma parameter for gaussian distribution
Returns
-------
kernel : np.ndarray
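For reference, a kernel like this is typically the outer product of a 1D gaussian with itself, normalized to sum to one. A minimal standalone sketch (the exact normalization inside phygnn may differ):

import numpy as np

def make_2d_gaussian_kernel(edge_len, sigma=1.0):
    # evaluate a 1D gaussian on a symmetric grid of length edge_len
    ax = np.linspace(-(edge_len - 1) / 2.0, (edge_len - 1) / 2.0, edge_len)
    gauss = np.exp(-0.5 * np.square(ax) / np.square(sigma))
    # outer product gives the 2D kernel; normalize so it sums to 1
    kernel = np.outer(gauss, gauss)
    return kernel / np.sum(kernel)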
@@ -213,10 +215,12 @@ def get_config(self):

def call(self, x):
"""Operates on x with the specified function
Parameters
----------
x : tf.Tensor
Input tensor
Returns
-------
x : tf.Tensor
@@ -252,6 +256,7 @@ def __init__(self, axis, mean=1, stddev=0.1):
"""

super().__init__()
self.rank = None
self._axis = axis
self._rand_shape = None
self._mean = tf.constant(mean, dtype=tf.dtypes.float32)
@@ -269,6 +274,7 @@ def build(self, input_shape):
"""
shape = np.ones(len(input_shape), dtype=np.int32)
shape[self._axis] = input_shape[self._axis]
self.rank = len(input_shape)
self._rand_shape = tf.constant(shape, dtype=tf.dtypes.int32)

def call(self, x):
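This hunk is the core of the commit: build() now records len(input_shape) on the layer, so code driving a sup3r forward pass can read the expected input rank (4D spatial vs 5D spatiotemporal) without tracing a tensor through the model. A simplified sketch of the idea; the class name here is hypothetical:

import numpy as np
import tensorflow as tf

class NoiseAxisSketch(tf.keras.layers.Layer):
    # simplified stand-in for the gaussian noise layer in this diff

    def __init__(self, axis):
        super().__init__()
        self.rank = None        # unknown until build() sees an input
        self._axis = axis
        self._rand_shape = None

    def build(self, input_shape):
        shape = np.ones(len(input_shape), dtype=np.int32)
        shape[self._axis] = input_shape[self._axis]
        self.rank = len(input_shape)  # the new attribute from this commit
        self._rand_shape = tf.constant(shape, dtype=tf.int32)

layer = NoiseAxisSketch(axis=-1)
layer.build((None, 10, 10, 24, 2))  # spatiotemporal-shaped input
assert layer.rank == 5              # 5D: model expects space + time dims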
@@ -351,7 +357,7 @@ def __init__(self, spatial_mult=1):
"""
Parameters
----------
spatial_multiplier : int
spatial_mult : int
Number of times to multiply the spatial dimensions. Note that the
spatial expansion is an un-packing of the feature dimension. For
example, if the input layer has shape (123, 5, 5, 16) with
@@ -435,14 +441,14 @@ def __init__(self, spatial_mult=1, temporal_mult=1,
"""
Parameters
----------
spatial_multiplier : int
spatial_mult : int
Number of times to multiply the spatial dimensions. Note that the
spatial expansion is an un-packing of the feature dimension. For
example, if the input layer has shape (123, 5, 5, 24, 16) with
multiplier=2 the output shape will be (123, 10, 10, 24, 4). The
input feature dimension must be divisible by the spatial multiplier
squared.
temporal_multiplier : int
temporal_mult : int
Number of times to multiply the temporal dimension. For example,
if the input layer has shape (123, 5, 5, 24, 2) with multiplier=2
the output shape will be (123, 5, 5, 48, 2).
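These two hunks only rename the documented parameters to match the actual signature (spatial_mult, temporal_mult). The shape arithmetic they describe can be sanity-checked with stock TensorFlow ops; depth-to-space for the spatial un-packing and repetition along the time axis are illustrative stand-ins, not necessarily phygnn's exact kernels:

import tensorflow as tf

# spatial: (123, 5, 5, 16) -> (123, 10, 10, 4) with spatial_mult=2;
# features are un-packed into space, so 16 / 2**2 = 4 channels remain
x4d = tf.zeros((123, 5, 5, 16))
print(tf.nn.depth_to_space(x4d, block_size=2).shape)

# temporal: (123, 5, 5, 24, 2) -> (123, 5, 5, 48, 2) with temporal_mult=2
x5d = tf.zeros((123, 5, 5, 24, 2))
print(tf.repeat(x5d, repeats=2, axis=3).shape)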
@@ -603,18 +609,17 @@ def call(self, x):
if self._cache is None:
self._cache = x
return x
try:
out = tf.add(x, self._cache)
except Exception as e:
msg = ('Could not add SkipConnection "{}" data cache of '
'shape {} to input of shape {}.'
.format(self._name, self._cache.shape, x.shape))
logger.error(msg)
raise RuntimeError(msg) from e
else:
try:
out = tf.add(x, self._cache)
except Exception as e:
msg = ('Could not add SkipConnection "{}" data cache of '
'shape {} to input of shape {}.'
.format(self._name, self._cache.shape, x.shape))
logger.error(msg)
raise RuntimeError(msg) from e
else:
self._cache = None
return out
self._cache = None
return out
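The rewrite drops the redundant else: after the early return, dedenting the add-and-reset logic without changing behavior. A usage sketch of the two-call protocol (the name keyword is assumed to be accepted, as in the other layers in this file):

import tensorflow as tf
from phygnn.layers.custom_layers import SkipConnection

skip = SkipConnection(name='skip_0')
x = tf.ones((4, 8, 8, 2))
x = skip(x)   # first call: caches the tensor and returns it unchanged
x = x * 3.0   # stand-in for the layers between the two skip endpoints
x = skip(x)   # second call: returns x + cache, then clears the cache
print(float(tf.reduce_mean(x)))  # 4.0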


class SqueezeAndExcitation(tf.keras.layers.Layer):
@@ -834,7 +839,8 @@ def __init__(self, name=None):
"""
super().__init__(name=name)

def call(self, x, hi_res_adder):
@staticmethod
def call(x, hi_res_adder):
"""Adds hi-resolution data to the input tensor x in the middle of a
sup3r resolution network.
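call is made a static method because it touches no instance state. Functionally it is an elementwise add of an exogenous hi-res field; a sketch with illustrative shapes (the adder must match or broadcast against x):

import tensorflow as tf

x = tf.zeros((1, 10, 10, 4))                     # tensor mid-network
hi_res_adder = tf.random.normal((1, 10, 10, 4))  # hypothetical hi-res field
out = tf.add(x, hi_res_adder)                    # what Sup3rAdder.call returns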
@@ -869,7 +875,8 @@ def __init__(self, name=None):
"""
super().__init__(name=name)

def call(self, x, hi_res_feature):
@staticmethod
def call(x, hi_res_feature):
"""Concatenates a hi-resolution feature to the input tensor x in the
middle of a sup3r resolution network.
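Same static-method change as Sup3rAdder. Here a hi-res feature channel is appended on the trailing feature axis (axis assumed from the docstring's "concatenates"):

import tensorflow as tf

x = tf.zeros((1, 10, 10, 4))
hi_res_feature = tf.random.normal((1, 10, 10, 1))  # e.g. hi-res topography
out = tf.concat((x, hi_res_feature), axis=-1)
print(out.shape)  # (1, 10, 10, 5)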
@@ -940,7 +947,8 @@ class SigLin(tf.keras.layers.Layer):
y = x + 0.5 where x>=0.5
"""

def call(self, x):
@staticmethod
def call(x):
"""Operates on x with SigLin
Parameters
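Only the linear branch of SigLin is visible in this hunk (y = x + 0.5 where x >= 0.5); the branch below the threshold is assumed here to be a plain sigmoid, per the layer's name. A sketch of the piecewise map under that assumption:

import tensorflow as tf

def siglin(x):
    # linear shift above the threshold (shown in the diff); sigmoid
    # below it (assumed, not shown in this hunk)
    return tf.where(x >= 0.5, x + 0.5, tf.sigmoid(x))

x = tf.constant([-2.0, 0.0, 0.5, 2.0])
print(siglin(x).numpy())  # [sigmoid(-2), 0.5, 1.0, 2.5]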
@@ -1002,8 +1010,7 @@ def build(self, input_shape):
def _logt(self, x):
if not self.inverse:
return tf.math.log(x + self.adder) * self.scalar
else:
return tf.math.exp(x / self.scalar) - self.adder
return tf.math.exp(x / self.scalar) - self.adder

def call(self, x):
"""Operates on x with (inverse) log transform
@@ -1021,16 +1028,15 @@ def call(self, x):

if self.idf is None:
return self._logt(x)
else:
out = []
for idf in range(x.shape[-1]):
if idf in self.idf:
out.append(self._logt(x[..., idf:idf + 1]))
else:
out.append(x[..., idf:idf + 1])
out = []
for idf in range(x.shape[-1]):
if idf in self.idf:
out.append(self._logt(x[..., idf:idf + 1]))
else:
out.append(x[..., idf:idf + 1])

out = tf.concat(out, -1, name='concat')
return out
out = tf.concat(out, -1, name='concat')
return out
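Both LogTransform hunks are pure de-nesting: the else: branches after returns are removed, with identical behavior. Two properties worth checking numerically: the inverse branch of _logt undoes the forward branch, and the idf list limits the transform to selected feature channels. A self-contained sketch (adder/scalar values are illustrative):

import tensorflow as tf

adder, scalar = 1.0, 2.0
x = tf.constant([0.0, 1.0, 10.0])
fwd = tf.math.log(x + adder) * scalar    # forward branch of _logt
inv = tf.math.exp(fwd / scalar) - adder  # inverse branch recovers x
print(inv.numpy())                       # ~[0., 1., 10.]

# per-feature selection: with idf=[0] only channel 0 is transformed;
# channel 1 passes through before the channels are re-concatenated
def selective_logt(x, idf, adder=1.0, scalar=2.0):
    out = []
    for i in range(x.shape[-1]):
        chan = x[..., i:i + 1]
        out.append(tf.math.log(chan + adder) * scalar if i in idf else chan)
    return tf.concat(out, -1)

xb = tf.constant([[1.0, 1.0]])
print(selective_logt(xb, idf=[0]).numpy())  # [[1.386..., 1.0]]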


class UnitConversion(tf.keras.layers.Layer):
