Add options to train and export TFLite compatible models #157
base: master
Conversation
I converted my model and now it only outputs empty black files. Any tips? |
You would have to retrain your model using the code from this PR. The issue is that the original implementation uses batch normalization with batch_size=1. This is a degenerate case, according to my understanding of the definitions, where batch normalization becomes instance normalization. However, I can't speak to how TensorFlow implements these operations. At any rate, TFLite won't accept batch_normalization with batch_size=1 and requires instance_norm instead. This requires retraining. A smart conversion tool could probably keep all the weights, but change some ID field in the TensorFlow protobuf representing batch_norm to instance_norm. Here are the command lines I used to train and export a TFLite model:
|
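To make the degenerate case explicit (this is just the definitions, not code from the PR): with an input of shape N×H×W×C, training-mode batch norm normalizes each channel with statistics taken over the batch and spatial axes, while instance norm takes them per sample over the spatial axes only, so with N = 1 the two coincide:

```math
\mu^{BN}_{c} = \frac{1}{NHW}\sum_{n,h,w} x_{nhwc}, \qquad
\mu^{IN}_{nc} = \frac{1}{HW}\sum_{h,w} x_{nhwc}, \qquad
N = 1 \;\Rightarrow\; \mu^{BN}_{c} = \mu^{IN}_{1c}
```

The same holds for the variances, so the normalized activations (and the learned scale/offset) match up to numerical noise.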
Do we need to add update_ops to train_op as well? From the documentation:
|
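For reference, this is the pattern the TF 1.x tf.layers.batch_normalization documentation describes; a minimal sketch with a toy model standing in for pix2pix (not code from this PR):

```python
import tensorflow as tf  # TF 1.x

x = tf.placeholder(tf.float32, [None, 256, 256, 3])
y = tf.placeholder(tf.float32, [None, 256, 256, 3])

h = tf.layers.conv2d(x, 32, 4, strides=2, padding="same")
h = tf.layers.batch_normalization(h, training=True)  # adds moving-average updates to UPDATE_OPS
h = tf.nn.relu(h)
out = tf.layers.conv2d_transpose(h, 3, 4, strides=2, padding="same")

loss = tf.reduce_mean(tf.abs(out - y))

# The moving_mean / moving_variance updates live in UPDATE_OPS and are not part of
# the gradient computation, so they must be made an explicit dependency of train_op.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(2e-4).minimize(loss)
```

Without the control dependency the moving statistics never get updated, and a graph later run with training=False will normalize with stale statistics.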
Yes, tf.layers.batch_normalization is the same as tf.contrib.layers.instance_norm when batch_size=1. Test: max abs diff: 1.1920929e-07
|
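A rough reconstruction of that kind of test, assuming TF 1.x with tf.contrib available (the epsilon values are pinned to the same number because the two layers use different defaults):

```python
import numpy as np
import tensorflow as tf  # TF 1.x with tf.contrib

np.random.seed(0)
x_np = np.random.randn(1, 64, 64, 8).astype(np.float32)   # batch_size = 1

x = tf.placeholder(tf.float32, [1, 64, 64, 8])
# Batch norm in training mode with a single example normalizes over H, W per channel...
bn = tf.layers.batch_normalization(x, training=True, epsilon=1e-5)
# ...which is exactly what instance norm does for each example.
inn = tf.contrib.layers.instance_norm(x, epsilon=1e-5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    bn_out, in_out = sess.run([bn, inn], feed_dict={x: x_np})
    print("max abs diff:", np.abs(bn_out - in_out).max())   # ~1e-7, i.e. float32 noise
```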
Ahh-hah! Good to know, @mrgloom. Thanks. You can therefore convert the models using that equivalence. I don't have time right now to update this PR with code to do the conversion, but a nice utility to convert a previously trained pix2pix model to a TFLite-compatible one should be possible: load the previously trained model and a new model from this PR, and transfer the weights across. |
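A sketch of what such a conversion utility could look like. This is entirely hypothetical: the checkpoint paths, the batch_normalization-to-InstanceNorm name mapping, and the assumption that gamma/beta line up between the two layers are all illustrative, not code from this PR.

```python
import tensorflow as tf  # TF 1.x

def convert(old_ckpt, new_ckpt, out_ckpt):
    """Copy matching weights from an old batch_norm checkpoint into a new instance_norm graph."""
    reader = tf.train.NewCheckpointReader(old_ckpt)
    old_vars = reader.get_variable_to_shape_map()

    with tf.Session() as sess:
        # Load the new (instance_norm) graph and its initial checkpoint.
        new_saver = tf.train.import_meta_graph(new_ckpt + ".meta")
        new_saver.restore(sess, new_ckpt)

        # Assumed mapping: InstanceNorm/{gamma,beta} <- batch_normalization/{gamma,beta};
        # moving_mean / moving_variance have no counterpart and are simply dropped.
        assign_ops = []
        for var in tf.global_variables():
            old_name = var.op.name.replace("InstanceNorm", "batch_normalization")
            if old_name in old_vars:
                value = reader.get_tensor(old_name)
                if value.shape == tuple(var.shape.as_list()):
                    assign_ops.append(tf.assign(var, value))
        sess.run(assign_ops)
        tf.train.Saver().save(sess, out_ckpt)

# Hypothetical usage:
# convert("old_batchnorm_model/model-latest",
#         "new_instancenorm_model/model-0",
#         "converted_instancenorm_model/model")
```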
These are the changes needed to get pix2pix-tensorflow running on mobile.
tf.layers.batch_normalization() on TFLite requires training=False, and that the model was trained with training=True and batch_size > 1. Also, TFLite has no tf.tanh(), tf.image.convert_image_dtype(), and some other ops.
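One way (not necessarily what this PR does) to avoid unsupported ops is to express them with ops the converter does accept; for example, tanh via the sigmoid identity and convert_image_dtype via an explicit cast and rescale:

```python
import tensorflow as tf  # TF 1.x

def tanh_compat(x):
    # Mathematically identical to tf.tanh: tanh(x) = 2*sigmoid(2x) - 1.
    # Only helps if sigmoid/mul/sub are supported by your TFLite version (an assumption).
    return 2.0 * tf.sigmoid(2.0 * x) - 1.0

def uint8_to_float(img):
    # Same result as tf.image.convert_image_dtype(img, tf.float32) for uint8 input:
    # cast and rescale to [0, 1].
    return tf.cast(img, tf.float32) * (1.0 / 255.0)
```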
Updating the batch_normalization TensorFlow variables used when training=False (the moving mean and variance, which aren't trainable variables) requires the UPDATE_OPS dependencies.
I have to re-train the model using tf.contrib.layers.instance_norm() instead of tf.layers.batch_normalization() with batch_size=1. It seems to work well. Is it the exact same thing?
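In the model code the change boils down to swapping the normalization helper. A minimal sketch, with an illustrative function and flag name (not the PR's actual code; the batch-norm arguments are only meant to mirror typical pix2pix settings):

```python
import tensorflow as tf  # TF 1.x with tf.contrib

def norm(inputs, use_instance_norm=True):
    """Normalization layer for the generator/discriminator.

    use_instance_norm=True gives a TFLite-friendly graph; False keeps the
    original batch-norm behaviour (training=True, batch statistics).
    """
    if use_instance_norm:
        return tf.contrib.layers.instance_norm(inputs, epsilon=1e-5)
    return tf.layers.batch_normalization(
        inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True,
        gamma_initializer=tf.random_normal_initializer(1.0, 0.02))
```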