# Source code for tensorlayer.layers.importer

#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf

from tensorlayer.layers.core import Layer
from tensorlayer.layers.core import TF_GRAPHKEYS_VARIABLES

from tensorlayer import logging

from tensorlayer.decorators import deprecated
from tensorlayer.decorators import deprecated_alias

__all__ = [
    'SlimNetsLayer',
    'KerasLayer',
    'EstimatorLayer',
]


class SlimNetsLayer(Layer):
    """A layer that merges TF-Slim models into TensorLayer.

    Models can be found in
    `slim-model <https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models>`__,
    see Inception V3 example on
    `Github <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_inceptionV3_tfslim.py>`__.

    Parameters
    ----------
    prev_layer : :class:`Layer`
        Previous layer.
    slim_layer : a slim network function
        The network you want to stack onto, end with ``return net, end_points``.
    slim_args : dictionary
        The arguments for the slim model.
    name : str
        A unique layer name.

    Raises
    ------
    ValueError
        If ``slim_layer`` is None.
    RuntimeError
        If no variables are found in the variable collection under this layer's name.

    Notes
    -----
    - TF-Slim returns its layers in the ``end_points`` dictionary, so the ``all_layers``
      of this network follow that dictionary's iteration order, which may not be the
      construction order. The ``all_params`` come from the variable collection and are
      in order.

    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(
            self,
            prev_layer,
            slim_layer,
            slim_args=None,
            name='tfslim_layer',
    ):

        # Validate before the base class does any setup.
        if slim_layer is None:
            raise ValueError("slim layer is None")

        super(SlimNetsLayer, self).__init__(prev_layer=prev_layer, slim_args=slim_args, name=name)

        # Callables such as ``functools.partial`` objects have no ``__name__``;
        # fall back to their repr so that logging cannot raise AttributeError.
        slim_name = getattr(slim_layer, '__name__', repr(slim_layer))
        logging.info("SlimNetsLayer %s: %s" % (self.name, slim_name))

        with tf.variable_scope(name):
            # ``self.slim_args`` is set by the base-class constructor from ``slim_args``.
            # NOTE(review): presumably a None ``slim_args`` is normalised to {} there — confirm in Layer.
            self.outputs, end_points = slim_layer(self.inputs, **self.slim_args)

        # Collect every variable the slim network created under this layer's name.
        slim_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=self.name)

        if not slim_variables:
            raise RuntimeError(
                "No variables found under %s : the name of SlimNetsLayer should be matched with the beginning of the ckpt file.\n"
                "see tutorial_inceptionV3_tfslim.py for more details" % self.name
            )

        # Register the intermediate tensors exposed by the slim network as this layer's layers.
        self._add_layers(list(end_points.values()))
        self._add_params(slim_variables)
@deprecated(
    date="2018-06-30", instructions="This layer will be deprecated soon as :class:`LambdaLayer` can do the same thing"
)
class KerasLayer(Layer):
    """A layer to import Keras layers into TensorLayer.

    Example can be found here
    `tutorial_keras.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_keras.py>`__.

    Parameters
    ----------
    prev_layer : :class:`Layer`
        Previous layer
    keras_layer : function
        A tensor in tensor out function for building model.
    keras_args : dictionary
        The arguments for the `keras_layer`.
    name : str
        A unique layer name.

    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(
            self,
            prev_layer,
            keras_layer,
            keras_args=None,
            name='keras_layer',
    ):

        super(KerasLayer, self).__init__(prev_layer=prev_layer, keras_args=keras_args, name=name)

        logging.info("KerasLayer %s: %s" % (self.name, keras_layer))

        logging.warning("This API will be removed, please use LambdaLayer instead.")

        with tf.variable_scope(name) as vs:
            # ``self.keras_args`` is set by the base-class constructor from ``keras_args``.
            self.outputs = keras_layer(self.inputs, **self.keras_args)
            # Every variable created while building the Keras function becomes a parameter of this layer.
            variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)

        self._add_layers(self.outputs)
        self._add_params(variables)


@deprecated(
    date="2018-06-30", instructions="This layer will be deprecated soon as :class:`LambdaLayer` can do the same thing"
)
class EstimatorLayer(Layer):
    """A layer that accepts a user-defined model.

    It is similar with :class:`KerasLayer`, see
    `tutorial_keras.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_keras.py>`__.

    Parameters
    ----------
    prev_layer : :class:`Layer`
        Previous layer
    model_fn : function
        A tensor in tensor out function for building model.
    layer_args : dictionary
        The arguments for the `model_fn`.
    name : str
        A unique layer name.

    Raises
    ------
    ValueError
        If ``model_fn`` is None.

    """

    @deprecated_alias(
        layer='prev_layer', args='layer_args', end_support_version=1.9
    )  # TODO remove this line for the 1.9 release
    def __init__(
            self,
            prev_layer,
            model_fn,
            layer_args=None,
            name='estimator_layer',
    ):
        # Validate before the base class does any setup and before logging,
        # so that a missing model function never logs "None" as if it were valid.
        if model_fn is None:
            raise ValueError('model fn is None')

        super(EstimatorLayer, self).__init__(prev_layer=prev_layer, layer_args=layer_args, name=name)

        logging.info("EstimatorLayer %s: %s" % (self.name, model_fn))

        logging.warning("This API will be removed, please use LambdaLayer instead.")

        with tf.variable_scope(name) as vs:
            # ``self.layer_args`` is set by the base-class constructor from ``layer_args``.
            self.outputs = model_fn(self.inputs, **self.layer_args)
            # Every variable created while building the model becomes a parameter of this layer.
            variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)

        self._add_layers(self.outputs)
        self._add_params(variables)