Source code for tensorlayer.models.mobilenetv1

#! /usr/bin/python
# -*- coding: utf-8 -*-
"""MobileNet for ImageNet."""

import os

import tensorflow as tf

from tensorlayer import logging
from tensorlayer.files import (assign_weights, load_npz, maybe_download_and_extract)
from tensorlayer.layers import (BatchNorm, Conv2d, DepthwiseConv2d, Flatten, GlobalMeanPool2d, Input, Reshape)
from tensorlayer.models import Model

__all__ = [
    'MobileNetV1',
]

layer_names = [
    'conv', 'depth1', 'depth2', 'depth3', 'depth4', 'depth5', 'depth6', 'depth7', 'depth8', 'depth9', 'depth10',
    'depth11', 'depth12', 'depth13', 'globalmeanpool', 'reshape', 'out'
]
n_filters = [32, 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024]


def conv_block(n, n_filter, filter_size=(3, 3), strides=(1, 1), name='conv_block'):
    # ref: https://github.com/keras-team/keras/blob/master/keras/applications/mobilenet.py
    n = Conv2d(n_filter, filter_size, strides, b_init=None, name=name + '.conv')(n)
    n = BatchNorm(decay=0.99, act=tf.nn.relu6, name=name + '.batchnorm')(n)
    return n


def depthwise_conv_block(n, n_filter, strides=(1, 1), name="depth_block"):
    n = DepthwiseConv2d((3, 3), strides, b_init=None, name=name + '.depthwise')(n)
    n = BatchNorm(decay=0.99, act=tf.nn.relu6, name=name + '.batchnorm1')(n)
    n = Conv2d(n_filter, (1, 1), (1, 1), b_init=None, name=name + '.conv')(n)
    n = BatchNorm(decay=0.99, act=tf.nn.relu6, name=name + '.batchnorm2')(n)
    return n


def restore_params(network, path='models'):
    logging.info("Restore pre-trained parameters")
    maybe_download_and_extract(
        'mobilenet.npz', path, 'https://github.com/tensorlayer/pretrained-models/raw/master/models/',
        expected_bytes=25600116
    )  # ls -al
    params = load_npz(name=os.path.join(path, 'mobilenet.npz'))
    # for idx, net_weight in enumerate(network.all_weights):
    #     if 'batchnorm' in net_weight.name:
    #         params[idx] = params[idx].reshape(1, 1, 1, -1)
    assign_weights(params[:len(network.all_weights)], network)
    del params


[docs]def MobileNetV1(pretrained=False, end_with='out', name=None): """Pre-trained MobileNetV1 model (static mode). Input shape [?, 224, 224, 3], value range [0, 1]. Parameters ---------- pretrained : boolean Whether to load pretrained weights. Default False. end_with : str The end point of the model [conv, depth1, depth2 ... depth13, globalmeanpool, out]. Default ``out`` i.e. the whole model. name : None or str Name for this model. Examples --------- Classify ImageNet classes, see `tutorial_models_mobilenetv1.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_models_mobilenetv1.py>`__ >>> # get the whole model with pretrained weights >>> mobilenetv1 = tl.models.MobileNetV1(pretrained=True) >>> # use for inferencing >>> output = mobilenetv1(img1, is_train=False) >>> prob = tf.nn.softmax(output)[0].numpy() Extract features and Train a classifier with 100 classes >>> # get model without the last layer >>> cnn = tl.models.MobileNetV1(pretrained=True, end_with='reshape').as_layer() >>> # add one more layer and build new model >>> ni = Input([None, 224, 224, 3], name="inputs") >>> nn = cnn(ni) >>> nn = Conv2d(100, (1, 1), (1, 1), name='out')(nn) >>> nn = Flatten(name='flatten')(nn) >>> model = tl.models.Model(inputs=ni, outputs=nn) >>> # train your own classifier (only update the last layer) >>> train_params = model.get_layer('out').trainable_weights Returns ------- static MobileNetV1. """ ni = Input([None, 224, 224, 3], name="input") for i in range(len(layer_names)): if i == 0: n = conv_block(ni, n_filters[i], strides=(2, 2), name=layer_names[i]) elif layer_names[i] in ['depth2', 'depth4', 'depth6', 'depth12']: n = depthwise_conv_block(n, n_filters[i], strides=(2, 2), name=layer_names[i]) elif layer_names[i] == 'globalmeanpool': n = GlobalMeanPool2d(name='globalmeanpool')(n) elif layer_names[i] == 'reshape': n = Reshape([-1, 1, 1, 1024], name='reshape')(n) elif layer_names[i] == 'out': n = Conv2d(1000, (1, 1), (1, 1), name='out')(n) n = Flatten(name='flatten')(n) else: n = depthwise_conv_block(n, n_filters[i], name=layer_names[i]) if layer_names[i] == end_with: break network = Model(inputs=ni, outputs=n, name=name) if pretrained: restore_params(network) return network