#! /usr/bin/python
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import moving_averages
import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Layer
# Public API of this module.
__all__ = [
'LocalResponseNorm',
'BatchNorm', # FIXME: whether to keep BatchNorm
'BatchNorm1d',
'BatchNorm2d',
'BatchNorm3d',
'InstanceNorm',
'InstanceNorm1d',
'InstanceNorm2d',
'InstanceNorm3d',
'LayerNorm',
'GroupNorm',
'SwitchNorm',
]
class LocalResponseNorm(Layer):
    """Local Response Normalization layer.

    See ``tf.nn.local_response_normalization`` (a.k.a. ``tf.nn.lrn``). The 4-D
    input is treated as a 3-D array of 1-D vectors along the last dimension,
    and each vector is normalized independently: every component is divided by
    the weighted square-sum of the inputs within ``depth_radius``.

    Parameters
    -----------
    depth_radius : int
        0-D. Half-width of the 1-D normalization window.
    bias : float
        An offset, usually positive, which avoids dividing by 0.
    alpha : float
        A scale factor, usually positive.
    beta : float
        An exponent.
    name : None or str
        A unique layer name.
    """

    def __init__(self, depth_radius=None, bias=None, alpha=None, beta=None, name=None):
        super().__init__(name)
        self.depth_radius = depth_radius
        self.bias = bias
        self.alpha = alpha
        self.beta = beta
        logging.info(
            "LocalResponseNorm %s: depth_radius: %s, bias: %s, alpha: %s, beta: %s" %
            (self.name, str(depth_radius), str(bias), str(alpha), str(beta))
        )

    def build(self, inputs):
        # LRN is a fixed transformation: there are no trainable weights.
        pass

    def forward(self, inputs):
        """Apply local response normalization to a 4D input tensor."""
        return tf.nn.lrn(inputs, depth_radius=self.depth_radius, bias=self.bias, alpha=self.alpha, beta=self.beta)
def _to_channel_first_bias(b):
    """Reshape a rank-1 bias tensor of shape [c] to [c, 1, 1] for NCHW broadcasting."""
    channels = int(b.shape[0])
    # A concrete channel count is used on purpose: the dynamic form [-1, 1, 1]
    # does not work with tensorRT.
    return tf.reshape(b, (channels, 1, 1))
def _bias_scale(x, b, data_format):
"""The multiplication counter part of tf.nn.bias_add."""
if data_format == 'NHWC':
return x * b
elif data_format == 'NCHW':
return x * b
else:
raise ValueError('invalid data_format: %s' % data_format)
def _bias_add(x, b, data_format):
    """Alternative implementation of ``tf.nn.bias_add`` that is compatible with tensorRT.

    ``b`` is assumed to already be broadcast-compatible with ``x`` in the given
    layout; both layouts therefore reduce to a plain elementwise add.
    """
    if data_format not in ('NHWC', 'NCHW'):
        raise ValueError('invalid data_format: %s' % data_format)
    return tf.add(x, b)
def _compute_shape(tensors):
if isinstance(tensors, list):
shape_mem = [t.get_shape().as_list() for t in tensors]
else:
shape_mem = tensors.get_shape().as_list()
return shape_mem
def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None):
    """Data-format aware version of ``tf.nn.batch_normalization``.

    The per-channel statistics/parameters are reshaped so they broadcast
    against ``x`` in either ``channels_last`` or ``channels_first`` layout,
    then ``a * x + b`` is computed via the tensorRT-friendly helpers.
    """
    rank = len(x.shape)
    if data_format == 'channels_last':
        param_shape = [1] * (rank - 1) + [-1]
    elif data_format == 'channels_first':
        param_shape = [1, -1] + [1] * (rank - 2)
    else:
        raise ValueError('invalid data_format: %s' % data_format)
    mean = tf.reshape(mean, param_shape)
    variance = tf.reshape(variance, param_shape)
    offset = tf.reshape(offset, param_shape)
    scale = tf.reshape(scale, param_shape)

    with ops.name_scope(name, 'batchnorm', [x, mean, variance, scale, offset]):
        inv = math_ops.rsqrt(variance + variance_epsilon)
        if scale is not None:
            inv *= scale
        a = math_ops.cast(inv, x.dtype)
        b = math_ops.cast(offset - mean * inv if offset is not None else -mean * inv, x.dtype)
        # Compute a * x + b with a custom data_format. TF has no bias_scale, and
        # tensorRT has a bug converting tf.nn.bias_add, so both are reimplemented
        # here to keep the model tensorRT-compatible.
        # See https://github.com/tensorlayer/openpose-plus/issues/75 for details.
        df = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
        return _bias_add(_bias_scale(x, a, df[data_format]), b, df[data_format])
class BatchNorm(Layer):
    """
    The :class:`BatchNorm` is a batch normalization layer for both fully-connected and
    convolution outputs. See ``tf.nn.batch_normalization`` and ``tf.nn.moments``.

    Parameters
    ----------
    decay : float
        A decay factor for `ExponentialMovingAverage`.
        Suggest to use a large value for large dataset.
    epsilon : float
        Epsilon.
    act : activation function
        The activation function of this layer.
    is_train : boolean
        Is being used for training or inference.
    beta_init : initializer or None
        The initializer for initializing beta, if None, skip beta.
        Usually you should not skip beta unless you know what happened.
    gamma_init : initializer or None
        The initializer for initializing gamma, if None, skip gamma.
        When the batch normalization layer is used instead of 'biases', or the next layer is linear, this can be
        disabled since the scaling can be done by the next layer. see `Inception-ResNet-v2 <https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py>`__
    moving_mean_init : initializer or None
        The initializer for initializing moving mean, if None, skip moving mean.
    moving_var_init : initializer or None
        The initializer for initializing moving var, if None, skip moving var.
    num_features: int
        Number of features for input tensor. Useful to build layer if using BatchNorm1d, BatchNorm2d or BatchNorm3d,
        but should be left as None if using BatchNorm. Default None.
    data_format : str
        channels_last 'channel_last' (default) or channels_first.
    name : None or str
        A unique layer name.

    Examples
    ---------
    With TensorLayer

    >>> net = tl.layers.Input([None, 50, 50, 32], name='input')
    >>> net = tl.layers.BatchNorm()(net)

    Notes
    -----
    The :class:`BatchNorm` is universally suitable for 3D/4D/5D input in static model, but should not be used
    in dynamic model where layer is built upon class initialization. So the argument 'num_features' should only be used
    for subclasses :class:`BatchNorm1d`, :class:`BatchNorm2d` and :class:`BatchNorm3d`. All the three subclasses are
    suitable under all kinds of conditions.

    References
    ----------
    - `Source <https://github.com/ry/tensorflow-resnet/blob/master/resnet.py>`__
    - `stackoverflow <http://stackoverflow.com/questions/38312668/how-does-one-do-inference-with-batch-normalization-with-tensor-flow>`__
    """

    def __init__(
        self,
        decay=0.9,
        epsilon=0.00001,
        act=None,
        is_train=False,
        beta_init=tl.initializers.zeros(),
        gamma_init=tl.initializers.random_normal(mean=1.0, stddev=0.002),
        moving_mean_init=tl.initializers.zeros(),
        moving_var_init=tl.initializers.zeros(),
        num_features=None,
        data_format='channels_last',
        name=None,
    ):
        super(BatchNorm, self).__init__(name=name, act=act)
        self.decay = decay
        self.epsilon = epsilon
        self.data_format = data_format
        self.beta_init = beta_init
        self.gamma_init = gamma_init
        self.moving_mean_init = moving_mean_init
        self.moving_var_init = moving_var_init
        self.num_features = num_features
        # Fix: forward() reads self.is_train, but the constructor previously only
        # logged the argument and never stored it, so standalone use raised
        # AttributeError. Store it here; a surrounding Model may overwrite it later.
        # NOTE(review): confirm the Layer base class does not set is_train itself.
        self.is_train = is_train
        # Reduction axes are resolved lazily in forward() from the input rank.
        self.axes = None

        if num_features is not None:
            # Eager build path used by the BatchNorm1d/2d/3d subclasses in
            # dynamic models, where the input shape is unknown at init time.
            self.build(None)
            self._built = True

        if self.decay < 0.0 or 1.0 < self.decay:
            raise ValueError("decay should be between 0 to 1")

        logging.info(
            "BatchNorm %s: decay: %f epsilon: %f act: %s is_train: %s" %
            (self.name, decay, epsilon, self.act.__name__ if self.act is not None else 'No Activation', is_train)
        )

    def __repr__(self):
        actstr = self.act.__name__ if self.act is not None else 'No Activation'
        s = ('{classname}(num_features={num_features}, decay={decay}' ', epsilon={epsilon}')
        s += (', ' + actstr)
        if self.name is not None:
            s += ', name="{name}"'
        s += ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def _get_param_shape(self, inputs_shape):
        # beta/gamma/moving stats are per-channel vectors of length `channels`.
        if self.data_format == 'channels_last':
            axis = -1
        elif self.data_format == 'channels_first':
            axis = 1
        else:
            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

        channels = inputs_shape[axis]
        params_shape = [channels]
        return params_shape

    def _check_input_shape(self, inputs):
        inputs_shape = _compute_shape(inputs)
        if len(inputs_shape) <= 1:
            raise ValueError('expected input at least 2D, but got {}D input'.format(inputs.ndim))

    def build(self, inputs_shape):
        params_shape = [self.num_features] if self.num_features is not None else self._get_param_shape(inputs_shape)

        self.beta, self.gamma = None, None
        if self.beta_init:
            self.beta = self._get_weights("beta", shape=params_shape, init=self.beta_init)
        if self.gamma_init:
            self.gamma = self._get_weights("gamma", shape=params_shape, init=self.gamma_init)

        # Moving statistics are not trainable; they are updated via EMA in forward().
        self.moving_mean = self._get_weights(
            "moving_mean", shape=params_shape, init=self.moving_mean_init, trainable=False
        )
        self.moving_var = self._get_weights(
            "moving_var", shape=params_shape, init=self.moving_var_init, trainable=False
        )

    def forward(self, inputs):
        self._check_input_shape(inputs)

        self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1
        if self.axes is None:
            # Reduce over every axis except the channel axis.
            self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]

        mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
        if self.is_train:
            # Training: normalize with batch statistics and update the EMAs.
            self.moving_mean = moving_averages.assign_moving_average(
                self.moving_mean, mean, self.decay, zero_debias=False
            )
            self.moving_var = moving_averages.assign_moving_average(self.moving_var, var, self.decay, zero_debias=False)
            outputs = batch_normalization(inputs, mean, var, self.beta, self.gamma, self.epsilon, self.data_format)
        else:
            # Inference: normalize with the accumulated moving statistics.
            outputs = batch_normalization(
                inputs, self.moving_mean, self.moving_var, self.beta, self.gamma, self.epsilon, self.data_format
            )
        if self.act:
            outputs = self.act(outputs)
        return outputs
class BatchNorm1d(BatchNorm):
    """Batch Normalization over 2D/3D input.

    Accepts a mini-batch of 1D inputs (optionally with an additional channel
    dimension): shape (N, C), (N, L, C) or (N, C, L). See :class:`BatchNorm`
    for the full parameter description.

    Examples
    ---------
    With TensorLayer

    >>> # in static model, no need to specify num_features
    >>> net = tl.layers.Input([None, 50, 32], name='input')
    >>> net = tl.layers.BatchNorm1d()(net)
    >>> # in dynamic model, build by specifying num_features
    >>> conv = tl.layers.Conv1d(32, 5, 1, in_channels=3)
    >>> bn = tl.layers.BatchNorm1d(num_features=32)
    """

    def _check_input_shape(self, inputs):
        ndim = len(_compute_shape(inputs))
        if ndim not in (2, 3):
            raise ValueError('expected input to be 2D or 3D, but got {}D input'.format(inputs.ndim))
class BatchNorm2d(BatchNorm):
    """Batch Normalization over 4D input.

    Accepts a mini-batch of 2D inputs with an additional channel dimension:
    shape (N, H, W, C) or (N, C, H, W). See :class:`BatchNorm` for the full
    parameter description.

    Examples
    ---------
    With TensorLayer

    >>> # in static model, no need to specify num_features
    >>> net = tl.layers.Input([None, 50, 50, 32], name='input')
    >>> net = tl.layers.BatchNorm2d()(net)
    >>> # in dynamic model, build by specifying num_features
    >>> conv = tl.layers.Conv2d(32, (5, 5), (1, 1), in_channels=3)
    >>> bn = tl.layers.BatchNorm2d(num_features=32)
    """

    def _check_input_shape(self, inputs):
        ndim = len(_compute_shape(inputs))
        if ndim != 4:
            raise ValueError('expected input to be 4D, but got {}D input'.format(inputs.ndim))
class BatchNorm3d(BatchNorm):
    """Batch Normalization over 5D input.

    Accepts a mini-batch of 3D inputs with an additional channel dimension:
    shape (N, D, H, W, C) or (N, C, D, H, W). See :class:`BatchNorm` for the
    full parameter description.

    Examples
    ---------
    With TensorLayer

    >>> # in static model, no need to specify num_features
    >>> net = tl.layers.Input([None, 50, 50, 50, 32], name='input')
    >>> net = tl.layers.BatchNorm3d()(net)
    >>> # in dynamic model, build by specifying num_features
    >>> conv = tl.layers.Conv3d(32, (5, 5, 5), (1, 1), in_channels=3)
    >>> bn = tl.layers.BatchNorm3d(num_features=32)
    """

    def _check_input_shape(self, inputs):
        ndim = len(_compute_shape(inputs))
        if ndim != 5:
            raise ValueError('expected input to be 5D, but got {}D input'.format(inputs.ndim))
class InstanceNorm(Layer):
    """
    The :class:`InstanceNorm` is an instance normalization layer for both fully-connected and
    convolution outputs. See ``tf.nn.batch_normalization`` and ``tf.nn.moments``.

    Parameters
    -----------
    act : activation function.
        The activation function of this layer.
    epsilon : float
        Epsilon.
    beta_init : initializer or None
        The initializer for initializing beta, if None, skip beta.
        Usually you should not skip beta unless you know what happened.
    gamma_init : initializer or None
        The initializer for initializing gamma, if None, skip gamma.
        When the instance normalization layer is used instead of 'biases', or the next layer is linear, this can be
        disabled since the scaling can be done by the next layer. see `Inception-ResNet-v2 <https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py>`__
    num_features: int
        Number of features for input tensor. Useful to build layer if using InstanceNorm1d, InstanceNorm2d or InstanceNorm3d,
        but should be left as None if using InstanceNorm. Default None.
    data_format : str
        channels_last 'channel_last' (default) or channels_first.
    name : None or str
        A unique layer name.

    Examples
    ---------
    With TensorLayer

    >>> net = tl.layers.Input([None, 50, 50, 32], name='input')
    >>> net = tl.layers.InstanceNorm()(net)

    Notes
    -----
    The :class:`InstanceNorm` is universally suitable for 3D/4D/5D input in static model, but should not be used
    in dynamic model where layer is built upon class initialization. So the argument 'num_features' should only be used
    for subclasses :class:`InstanceNorm1d`, :class:`InstanceNorm2d` and :class:`InstanceNorm3d`. All the three subclasses are
    suitable under all kinds of conditions.
    """

    def __init__(
        self, act=None, epsilon=0.00001, beta_init=tl.initializers.zeros(),
        gamma_init=tl.initializers.random_normal(mean=1.0, stddev=0.002), num_features=None,
        data_format='channels_last', name=None
    ):
        super(InstanceNorm, self).__init__(name=name, act=act)
        self.epsilon = epsilon
        self.beta_init = beta_init
        self.gamma_init = gamma_init
        self.num_features = num_features
        self.data_format = data_format

        if num_features is not None:
            # num_features only makes sense on the rank-specific subclasses,
            # where _get_param_shape can build without an input shape.
            if not isinstance(self, InstanceNorm1d) and not isinstance(self, InstanceNorm2d) and not isinstance(
                    self, InstanceNorm3d):
                raise ValueError(
                    "Please use InstanceNorm1d or InstanceNorm2d or InstanceNorm3d instead of InstanceNorm "
                    "if you want to specify 'num_features'."
                )
            self.build(None)
            self._built = True

        logging.info(
            "InstanceNorm %s: epsilon: %f act: %s " %
            (self.name, epsilon, self.act.__name__ if self.act is not None else 'No Activation')
        )

    def __repr__(self):
        actstr = self.act.__name__ if self.act is not None else 'No Activation'
        # Bug fix: 'num_features' was a literal (missing braces in the format
        # string) and the activation string was fused directly onto the epsilon
        # value with no separator.
        s = '{classname}(num_features={num_features}, epsilon={epsilon}'
        s += ', ' + actstr
        if self.name is not None:
            s += ', name="{name}"'
        s += ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def _get_param_shape(self, inputs_shape):
        # beta/gamma keep the full input rank with size 1 everywhere except the
        # channel axis, so they broadcast directly against the input.
        if self.data_format == 'channels_last':
            axis = len(inputs_shape) - 1
        elif self.data_format == 'channels_first':
            axis = 1
        else:
            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

        channels = inputs_shape[axis]
        params_shape = [1] * len(inputs_shape)
        params_shape[axis] = channels

        # Reduce over every axis except batch (0) and channel.
        axes = [i for i in range(len(inputs_shape)) if i != 0 and i != axis]
        return params_shape, axes

    def build(self, inputs_shape):
        params_shape, self.axes = self._get_param_shape(inputs_shape)

        self.beta, self.gamma = None, None
        if self.beta_init:
            self.beta = self._get_weights("beta", shape=params_shape, init=self.beta_init)
        if self.gamma_init:
            self.gamma = self._get_weights("gamma", shape=params_shape, init=self.gamma_init)

    def forward(self, inputs):
        # Per-sample statistics (keepdims=True so they broadcast against inputs).
        mean, var = tf.nn.moments(inputs, self.axes, keepdims=True)
        outputs = batch_normalization(inputs, mean, var, self.beta, self.gamma, self.epsilon, self.data_format)
        if self.act:
            outputs = self.act(outputs)
        return outputs
class InstanceNorm1d(InstanceNorm):
    """Instance Normalization over 3D input.

    Accepts a mini-batch of 1D inputs with an additional channel dimension:
    shape (N, L, C) or (N, C, L). See :class:`InstanceNorm` for the full
    parameter description.

    Examples
    ---------
    With TensorLayer

    >>> # in static model, no need to specify num_features
    >>> net = tl.layers.Input([None, 50, 32], name='input')
    >>> net = tl.layers.InstanceNorm1d()(net)
    >>> # in dynamic model, build by specifying num_features
    >>> conv = tl.layers.Conv1d(32, 5, 1, in_channels=3)
    >>> bn = tl.layers.InstanceNorm1d(num_features=32)
    """

    def _get_param_shape(self, inputs_shape):
        rank = 3
        if self.data_format == 'channels_last':
            axis = rank - 1
        elif self.data_format == 'channels_first':
            axis = 1
        else:
            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

        # num_features (dynamic models) takes precedence over the static shape.
        channels = self.num_features if self.num_features is not None else inputs_shape[axis]
        params_shape = [1] * rank
        params_shape[axis] = channels

        # Normalize over every axis except batch and channel.
        axes = [d for d in range(rank) if d not in (0, axis)]
        return params_shape, axes
class InstanceNorm2d(InstanceNorm):
    """Instance Normalization over 4D input.

    Accepts a mini-batch of 2D inputs with an additional channel dimension:
    shape (N, H, W, C) or (N, C, H, W). See :class:`InstanceNorm` for the full
    parameter description.

    Examples
    ---------
    With TensorLayer

    >>> # in static model, no need to specify num_features
    >>> net = tl.layers.Input([None, 50, 50, 32], name='input')
    >>> net = tl.layers.InstanceNorm2d()(net)
    >>> # in dynamic model, build by specifying num_features
    >>> conv = tl.layers.Conv2d(32, (5, 5), (1, 1), in_channels=3)
    >>> bn = tl.layers.InstanceNorm2d(num_features=32)
    """

    def _get_param_shape(self, inputs_shape):
        rank = 4
        if self.data_format == 'channels_last':
            axis = rank - 1
        elif self.data_format == 'channels_first':
            axis = 1
        else:
            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

        # num_features (dynamic models) takes precedence over the static shape.
        channels = self.num_features if self.num_features is not None else inputs_shape[axis]
        params_shape = [1] * rank
        params_shape[axis] = channels

        # Normalize over every axis except batch and channel.
        axes = [d for d in range(rank) if d not in (0, axis)]
        return params_shape, axes
class InstanceNorm3d(InstanceNorm):
    """Instance Normalization over 5D input.

    Accepts a mini-batch of 3D inputs with an additional channel dimension:
    shape (N, D, H, W, C) or (N, C, D, H, W). See :class:`InstanceNorm` for
    the full parameter description.

    Examples
    ---------
    With TensorLayer

    >>> # in static model, no need to specify num_features
    >>> net = tl.layers.Input([None, 50, 50, 50, 32], name='input')
    >>> net = tl.layers.InstanceNorm3d()(net)
    >>> # in dynamic model, build by specifying num_features
    >>> conv = tl.layers.Conv3d(32, (5, 5, 5), (1, 1), in_channels=3)
    >>> bn = tl.layers.InstanceNorm3d(num_features=32)
    """

    def _get_param_shape(self, inputs_shape):
        rank = 5
        if self.data_format == 'channels_last':
            axis = rank - 1
        elif self.data_format == 'channels_first':
            axis = 1
        else:
            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

        # num_features (dynamic models) takes precedence over the static shape.
        channels = self.num_features if self.num_features is not None else inputs_shape[axis]
        params_shape = [1] * rank
        params_shape[axis] = channels

        # Normalize over every axis except batch and channel.
        axes = [d for d in range(rank) if d not in (0, axis)]
        return params_shape, axes
# FIXME: the correctness of LayerNorm below has not been verified; needs testing
class LayerNorm(Layer):
    """
    The :class:`LayerNorm` class is for layer normalization, see
    `tf.contrib.layers.layer_norm <https://www.tensorflow.org/api_docs/python/tf/contrib/layers/layer_norm>`__.

    Parameters
    ----------
    center : boolean
        If True, add a trainable offset ``beta``.
    scale : boolean
        If True, add a trainable scale ``gamma``.
    act : activation function
        The activation function of this layer.
    epsilon : float
        Epsilon added to the variance for numerical stability.
    begin_norm_axis : int
        First axis of the normalization reduction; statistics are computed over
        axes ``[begin_norm_axis, rank)``.
    begin_params_axis : int
        beta/gamma cover ``inputs_shape[begin_params_axis:]``.
    beta_init : initializer
        The initializer for beta.
    gamma_init : initializer
        The initializer for gamma.
    data_format : str
        channels_last 'channel_last' (default) or channels_first.
    name : None or str
        A unique layer name.
    """

    def __init__(
        self,
        center=True,
        scale=True,
        act=None,
        epsilon=1e-12,
        begin_norm_axis=1,
        begin_params_axis=-1,
        beta_init=tl.initializers.zeros(),
        gamma_init=tl.initializers.ones(),
        data_format='channels_last',
        name=None,
    ):
        super(LayerNorm, self).__init__(name, act=act)
        self.center = center
        self.scale = scale
        self.epsilon = epsilon
        self.begin_norm_axis = begin_norm_axis
        self.begin_params_axis = begin_params_axis
        self.beta_init = beta_init
        self.gamma_init = gamma_init
        self.data_format = data_format

        logging.info(
            "LayerNorm %s: act: %s" % (self.name, self.act.__name__ if self.act is not None else 'No Activation')
        )

    def build(self, inputs_shape):
        # beta/gamma cover the trailing dims starting at begin_params_axis.
        params_shape = inputs_shape[self.begin_params_axis:]

        self.beta, self.gamma = None, None
        if self.center:
            self.beta = self._get_weights("beta", shape=params_shape, init=self.beta_init)
        if self.scale:
            self.gamma = self._get_weights("gamma", shape=params_shape, init=self.gamma_init)

        # Bug fix: materialize the axes as a list. A bare range object was
        # previously stored here, which is not a reliable `axes` argument for
        # tf.nn.moments across TF versions.
        self.norm_axes = list(range(self.begin_norm_axis, len(inputs_shape)))

    def forward(self, inputs):
        mean, var = tf.nn.moments(inputs, self.norm_axes, keepdims=True)
        # Layer normalization expressed through the shared batch_normalization
        # helper (per-sample statistics instead of per-batch).
        outputs = batch_normalization(
            inputs, mean, var, self.beta, self.gamma, self.epsilon, data_format=self.data_format
        )
        if self.act:
            outputs = self.act(outputs)
        return outputs
# with tf.compat.v1.variable_scope(name) as vs:
# self.outputs = tf.contrib.layers.layer_norm(
# self.inputs,
# center=center,
# scale=scale,
# activation_fn=self.act,
# reuse=reuse,
# variables_collections=variables_collections,
# outputs_collections=outputs_collections,
# trainable=trainable,
# begin_norm_axis=begin_norm_axis,
# begin_params_axis=begin_params_axis,
# scope='var',
# )
#
# variables = tf.compat.v1.get_collection("TF_GRAPHKEYS_VARIABLES", scope=vs.name)
#
# self._add_layers(self.outputs)
# self._add_params(variables)
class GroupNorm(Layer):
    """The :class:`GroupNorm` layer is for Group Normalization.

    See `tf.contrib.layers.group_norm <https://www.tensorflow.org/api_docs/python/tf/contrib/layers/group_norm>`__.
    Channels are split into `groups` groups, and mean/variance are computed
    within each group of each sample. Only 4D (image) inputs are supported.

    Parameters
    -----------
    groups : int
        The number of groups the channel dimension is split into.
    act : activation function
        The activation function of this layer.
    epsilon : float
        Epsilon added to the variance for numerical stability.
    data_format : str
        channels_last 'channel_last' (default) or channels_first.
    name : None or str
        A unique layer name
    """

    def __init__(self, groups=32, epsilon=1e-06, act=None, data_format='channels_last', name=None):
        super().__init__(name, act=act)
        self.groups = groups
        self.epsilon = epsilon
        self.data_format = data_format

        logging.info(
            "GroupNorm %s: act: %s" % (self.name, self.act.__name__ if self.act is not None else 'No Activation')
        )

    def build(self, inputs_shape):
        # Only 4D (N, H, W, C) / (N, C, H, W) inputs are handled below.
        if len(inputs_shape) != 4:
            raise Exception("This GroupNorm only supports 2D images.")

        # int_shape is the grouped view used in forward(): the channel dim is
        # split into (groups, channels // groups).
        # NOTE(review): inputs_shape entries are concatenated directly into a
        # tensor; a None batch dimension here would break tf.concat — confirm
        # callers always pass fully-defined static shapes.
        if self.data_format == 'channels_last':
            channels = inputs_shape[-1]
            self.int_shape = tf.concat(
                [
                    inputs_shape[0:3],
                    tf.convert_to_tensor(value=[self.groups, channels // self.groups])
                ], axis=0
            )
        elif self.data_format == 'channels_first':
            channels = inputs_shape[1]
            self.int_shape = tf.concat(
                [
                    inputs_shape[0:1],
                    tf.convert_to_tensor(value=[self.groups, channels // self.groups]),
                    inputs_shape[2:4],
                ],
                axis=0
            )
        else:
            raise ValueError("data_format must be 'channels_last' or 'channels_first'.")

        # The channel count must be divisible into the requested groups.
        if self.groups > channels:
            raise ValueError('Invalid groups %d for %d channels.' % (self.groups, channels))
        if channels % self.groups != 0:
            raise ValueError('%d channels is not commensurate with %d groups.' % (channels, self.groups))

        # gamma/beta are per-channel; for channels_first they are pre-shaped
        # [1, C, 1, 1] so they broadcast against the NCHW output.
        # NOTE(review): in the channels_last branch `shape` is a bare int while
        # other layers in this file pass a list — confirm _get_weights accepts both.
        if self.data_format == 'channels_last':
            self.gamma = self._get_weights("gamma", shape=channels, init=tl.initializers.ones())
            self.beta = self._get_weights("beta", shape=channels, init=tl.initializers.zeros())
        elif self.data_format == 'channels_first':
            self.gamma = self._get_weights("gamma", shape=[1, channels, 1, 1], init=tl.initializers.ones())
            self.beta = self._get_weights("beta", shape=[1, channels, 1, 1], init=tl.initializers.zeros())

    def forward(self, inputs):
        # Reshape so the channel dim becomes (groups, channels_per_group), then
        # compute statistics within each group (all spatial dims + the
        # per-group channel dim).
        x = tf.reshape(inputs, self.int_shape)
        if self.data_format == 'channels_last':
            mean, var = tf.nn.moments(x=x, axes=[1, 2, 4], keepdims=True)
        elif self.data_format == 'channels_first':
            mean, var = tf.nn.moments(x=x, axes=[2, 3, 4], keepdims=True)
        else:
            raise Exception("unknown data_format")
        x = (x - mean) / tf.sqrt(var + self.epsilon)
        # Restore the original layout, then apply the per-channel affine.
        outputs = tf.reshape(x, tf.shape(input=inputs)) * self.gamma + self.beta
        if self.act:
            outputs = self.act(outputs)
        return outputs
class SwitchNorm(Layer):
    """
    The :class:`SwitchNorm` is a switchable normalization.

    It blends batch, instance and layer statistics with learned softmax
    weights. Only 4D channels_last input is supported.

    Parameters
    ----------
    act : activation function
        The activation function of this layer.
    epsilon : float
        Epsilon added to the variance for numerical stability.
    beta_init : initializer or None
        The initializer for initializing beta, if None, skip beta.
        Usually you should not skip beta unless you know what happened.
    gamma_init : initializer or None
        The initializer for initializing gamma, if None, skip gamma.
        When this layer is used instead of 'biases', or the next layer is linear, this can be
        disabled since the scaling can be done by the next layer.
    moving_mean_init : initializer or None
        The initializer for initializing moving mean, if None, skip moving mean.
    data_format : str
        channels_last 'channel_last' (default) or channels_first.
    name : None or str
        A unique layer name.

    References
    ----------
    - `Differentiable Learning-to-Normalize via Switchable Normalization <https://arxiv.org/abs/1806.10779>`__
    - `Zhihu (CN) <https://zhuanlan.zhihu.com/p/39296570?utm_source=wechat_session&utm_medium=social&utm_oi=984862267107651584>`__
    """

    def __init__(
        self,
        act=None,
        epsilon=1e-5,
        beta_init=tl.initializers.constant(0.0),
        gamma_init=tl.initializers.constant(1.0),
        moving_mean_init=tl.initializers.zeros(),
        data_format='channels_last',
        name=None,
    ):
        super().__init__(name, act=act)
        self.epsilon = epsilon
        self.beta_init = beta_init
        self.gamma_init = gamma_init
        self.moving_mean_init = moving_mean_init
        self.data_format = data_format

        logging.info(
            "SwitchNorm %s: epsilon: %f act: %s" %
            (self.name, epsilon, self.act.__name__ if self.act is not None else 'No Activation')
        )

    def build(self, inputs_shape):
        # Only 4D channels_last input is supported by the axes used in forward().
        if len(inputs_shape) != 4:
            raise Exception("This SwitchNorm only supports 2D images.")
        if self.data_format != 'channels_last':
            raise Exception("This SwitchNorm only supports channels_last.")

        channels = inputs_shape[-1]
        # Per-channel affine parameters.
        self.gamma = self._get_weights("gamma", shape=[channels], init=self.gamma_init)
        self.beta = self._get_weights("beta", shape=[channels], init=self.beta_init)
        # Learned blend weights over the three statistics (batch/instance/layer).
        self.mean_weight_var = self._get_weights("mean_weight", shape=[3], init=tl.initializers.constant(1.0))
        self.var_weight_var = self._get_weights("var_weight", shape=[3], init=tl.initializers.constant(1.0))

    def forward(self, inputs):
        # Candidate statistics: batch (over N,H,W), instance (over H,W) and
        # layer (over H,W,C) — all kept broadcastable against the input.
        b_mean, b_var = tf.nn.moments(x=inputs, axes=[0, 1, 2], keepdims=True)
        i_mean, i_var = tf.nn.moments(x=inputs, axes=[1, 2], keepdims=True)
        l_mean, l_var = tf.nn.moments(x=inputs, axes=[1, 2, 3], keepdims=True)

        # Softmax makes the three blend weights a convex combination.
        w_mean = tf.nn.softmax(self.mean_weight_var)
        w_var = tf.nn.softmax(self.var_weight_var)

        mean = w_mean[0] * b_mean + w_mean[1] * i_mean + w_mean[2] * l_mean
        var = w_var[0] * b_var + w_var[1] * i_var + w_var[2] * l_var

        normalized = (inputs - mean) / (tf.sqrt(var + self.epsilon))
        outputs = normalized * self.gamma + self.beta
        if self.act:
            outputs = self.act(outputs)
        return outputs