Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options to train and export TFLite compatible models #157

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

GreenAppers
Copy link

@GreenAppers GreenAppers commented Mar 15, 2019

These are the changes needed to get pix2pix-tensorflow running on mobile.

tf.layers.batch_normalization() on TFLite requires training=False, and that the model was trained with training=True and batch_size > 1. Also TFLite has no tf.tanh(), tf.image.convert_image_dtype(), or others.

Updating the batch_normalization Tensorflow variables for training=False (which aren't trainable vars) requires the UPDATE_OPS dependencies.

I have to re-train the model using tf.contrib.layers.instance_norm() instead of tf.layers.batch_normalization() with batch_size=1. It seems to work good. Is it the exact same thing?

@dbx0
Copy link

dbx0 commented Sep 23, 2019

I converted my model and now it only outputs empty black files. Any tips?

@GreenAppers
Copy link
Author

GreenAppers commented Sep 23, 2019

You would have to retrain your model using the code from this PR. The issue is that the original implementation uses batch normalization with batch_size=1. This is a degenerate case, according to my understanding of the definitions, where batch normalization becomes instance normalization. However I can't speak to how Tensorflow implements these operations.

At any rate, TFLite won't accept batch_normalization with batch_size=1 and requires instance_norm instead. This requires retraining.

A smart conversion tool could probably keep all the weights, but change some ID field in the Tensorflow protobuf representing batch_norm to instance_norm.

Here are the command lines I used to train and export a TFLite model:

python pix2pix-tensorflow/pix2pix.py \
  --mode train \
  --max_epochs 200 \
  --save_freq 2000 \
  --norm_type tflite_compatible \
  --input_dir contours2cats \
  --output_dir contours2cats_train \
  --checkpoint contours2cats_train

python pix2pix-tensorflow/pix2pix.py \
  --mode export \
  --export_format tflite \
  --norm_type tflite_compatible \
  --checkpoint contours2cats_train \
  --output_dir contours2cats_export

@mrgloom
Copy link

mrgloom commented Jan 15, 2020

Do we need to add update_ops to train_op also?

From documentation:

  Note: when training, the moving_mean and moving_variance need to be updated.
  By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
  need to be executed alongside the `train_op`. Also, be sure to add any
  batch_normalization ops before getting the update_ops collection. Otherwise,
  update_ops will be empty, and training/inference will not work properly. For
  example:

    x_norm = tf.compat.v1.layers.batch_normalization(x, training=training)

    # ...

    update_ops = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS)
    train_op = optimizer.minimize(loss)
    train_op = tf.group([train_op, update_ops])

@mrgloom
Copy link

mrgloom commented Jan 15, 2020

I have to re-train the model using tf.contrib.layers.instance_norm() instead of tf.layers.batch_normalization() with batch_size=1. It seems to work good. Is it the exact same thing?

Yes, tf.layers.batch_normalization is the same as tf.contrib.layers.instance_norm when batch_size=1

Test: max abs diff: 1.1920929e-07

import tensorflow as tf
import numpy as np

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

np.random.seed(2019)
EPS = 1e-3
IS_CENTER = True
IS_SCALE = True

print('tf.__version__', tf.__version__)

def get_data_batch():
    bs = 1
    h = 3
    w = 3
    c = 4

    x_np = np.random.rand(bs, h, w, c)
    x_np = x_np.astype(np.float32)
    print('x_np.shape', x_np.shape)
    return x_np


def run_batch_norm(x_np):
    print('=' * 60)

    print('np.sum(x_np)', np.sum(x_np))

    with tf.Session() as sess:
        x_tf = tf.convert_to_tensor(x_np)

        z_tf = tf.layers.batch_normalization(x_tf,
                        axis=-1,
                        momentum=0.99,
                        epsilon=EPS,
                        center=IS_CENTER,
                        scale=IS_SCALE,
                        training=True,
                        trainable=True,
                        name=None,
                        reuse=None,
                        renorm=False,
                        renorm_clipping=None,
                        renorm_momentum=0.99,
                        fused=None,
                        virtual_batch_size=None,
                        adjustment=None)

        sess.run(tf.global_variables_initializer())
        z_np = sess.run(fetches=[z_tf], feed_dict={x_tf: x_np})[0]
        print('z_np.shape', z_np.shape)
        print('z_np', z_np)

        return z_np


def run_instsance_norm(x_np):
    print('=' * 60)

    print('np.sum(x_np)', np.sum(x_np))

    with tf.Session() as sess:
        x_tf = tf.convert_to_tensor(x_np)

        z_tf = tf.contrib.layers.instance_norm(x_tf,
                                               center=IS_CENTER,
                                               scale=IS_SCALE,
                                               epsilon=EPS,
                                               activation_fn=None,
                                               param_initializers=None,
                                               reuse=None,
                                               variables_collections=None,
                                               outputs_collections=None,
                                               trainable=True,
                                               data_format="NHWC",
                                               scope="instance_norm_scope")

        sess.run(tf.global_variables_initializer())
        z_np = sess.run(fetches=[z_tf], feed_dict={x_tf: x_np})[0]
        print('z_np.shape', z_np.shape)
        print('z_np', z_np)

        return z_np


def run_test():
    x_np = get_data_batch()

    z_np_1 = run_instsance_norm(x_np)

    z_np_2 = run_batch_norm(x_np)

    print('max abs diff:', np.max(np.abs(z_np_1-z_np_2)))


run_test()

@GreenAppers
Copy link
Author

Ahh-hah! Good to know, @mrgloom. Thanks.

You can therefore convert the models using get_weights and set_weights.

I don't have time now to update this PR with code to do the conversion. But a nice utility to convert a previously trained pix2pix model to a TFLite compatible one should be possible by loading the previously trained model, and a new model from this PR, and transferring the weights with get_weights and set_weights.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants