# -*- coding: utf-8 -*-
from .core import *
class MultiplexerLayer(Layer):
    """
    The :class:`MultiplexerLayer` forwards exactly one of several input layers
    to its output. The choice is made at run time by feeding an integer index
    to the :attr:`sel` placeholder.

    All input layers are stacked along a new leading axis before selection, so
    their outputs must have the same shape and dtype.
    See `tutorial_mnist_multiplexer.py`.

    Parameters
    ----------
    layers : a list of :class:`Layer`
        The input layers. Must contain at least one layer.
    name : str
        A unique layer name.

    Attributes
    ----------
    sel : placeholder
        A ``tf.int32`` placeholder. Feed it the index (into ``layers``) of the
        layer whose output should be forwarded.
    n_inputs : int
        The number of input layers.
    outputs : Tensor
        The output of the selected input layer.

    Raises
    ------
    ValueError
        If ``layers`` is empty.

    Notes
    -----
    ``sel`` is created without a static shape, so the static shape of
    :attr:`outputs` is generally unknown at graph-construction time. This is
    why the example below restores it with a :class:`ReshapeLayer`.

    Examples
    --------
    >>> x = tf.placeholder(tf.float32, shape=(None, 784), name='x')
    >>> # define the network
    >>> net_in = tl.layers.InputLayer(x, name='input')
    >>> net_in = tl.layers.DropoutLayer(net_in, keep=0.8, name='drop1')
    >>> # net 0
    >>> net_0 = tl.layers.DenseLayer(net_in, n_units=800, act=tf.nn.relu, name='net0/relu1')
    >>> net_0 = tl.layers.DropoutLayer(net_0, keep=0.5, name='net0/drop2')
    >>> net_0 = tl.layers.DenseLayer(net_0, n_units=800, act=tf.nn.relu, name='net0/relu2')
    >>> # net 1
    >>> net_1 = tl.layers.DenseLayer(net_in, n_units=800, act=tf.nn.relu, name='net1/relu1')
    >>> net_1 = tl.layers.DropoutLayer(net_1, keep=0.8, name='net1/drop2')
    >>> net_1 = tl.layers.DenseLayer(net_1, n_units=800, act=tf.nn.relu, name='net1/relu2')
    >>> net_1 = tl.layers.DropoutLayer(net_1, keep=0.8, name='net1/drop3')
    >>> net_1 = tl.layers.DenseLayer(net_1, n_units=800, act=tf.nn.relu, name='net1/relu3')
    >>> # multiplexer
    >>> net_mux = tl.layers.MultiplexerLayer(layers=[net_0, net_1], name='mux')
    >>> network = tl.layers.ReshapeLayer(net_mux, shape=(-1, 800), name='reshape')
    >>> network = tl.layers.DropoutLayer(network, keep=0.5, name='drop3')
    >>> # output layer
    >>> network = tl.layers.DenseLayer(network, n_units=10, act=tf.identity, name='output')
    >>> # at run time, add e.g. `net_mux.sel: 1` to the feed dict to forward net_1
    """

    def __init__(self, layers, name='mux_layer'):
        if not layers:
            # Stacking requires at least one tensor; fail early with a clear message.
            raise ValueError("MultiplexerLayer %s: `layers` must contain at least one layer" % name)

        Layer.__init__(self, prev_layer=layers, name=name)
        self.n_inputs = len(layers)

        # The output tensor of every candidate layer, in the order given.
        self.inputs = [layer.outputs for layer in layers]

        # `tf.pack` was renamed `tf.stack` in TF 1.0. Choose whichever exists
        # rather than wrapping the call in `try/except Exception`. That pattern
        # swallowed genuine errors raised by the op itself (e.g. mismatched
        # input shapes) and re-raised them as a misleading AttributeError from
        # the missing `tf.pack`.
        stack = getattr(tf, 'stack', None) or tf.pack
        # New leading axis of length n_inputs: [n_inputs, <input shape>...]
        all_inputs = stack(self.inputs, name=name)

        logging.info("MultiplexerLayer %s: n_inputs:%d" % (self.name, self.n_inputs))

        self.sel = tf.placeholder(tf.int32)
        # Take slice `sel` along the stacking axis, i.e. all_inputs[sel, :, ...].
        self.outputs = tf.gather(all_inputs, self.sel, name=name)

        self.all_layers.append(self.outputs)