ModelCheckPoint problem #127
Hi @redaelhail,
Yes, saving TensorFlow objects is not easy... A suggestion is to only save the weights of the networks. Here is a little example that should work:

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from adapt.feature_based import DANN
np.random.seed(0)
Xt = np.random.randn(100, 2)
Xs = np.random.randn(100, 2)
ys = np.random.choice(2, 100)
# Encoder network: maps the 2-d inputs to a 10-d feature space
def encoder():
    mod = tf.keras.Sequential()
    mod.add(tf.keras.layers.Dense(10, activation="relu", input_shape=(2,)))
    return mod

# Task network: binary classifier on top of the encoded features
def task():
    mod = tf.keras.Sequential()
    mod.add(tf.keras.layers.Dense(1, activation="sigmoid"))
    return mod

# Domain discriminator: predicts whether a sample comes from the source or target domain
def discriminator():
    mod = tf.keras.Sequential()
    mod.add(tf.keras.layers.Dense(1, activation="sigmoid"))
    return mod
chk = ModelCheckpoint('Model.hdf5',
                      monitor="loss",
                      verbose=1,
                      save_best_only=True,
                      save_weights_only=True,
                      mode='min',
                      save_freq=1)
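# Note: an integer save_freq counts batches, so save_freq=1 writes a checkpoint after
# every batch; use save_freq='epoch' to checkpoint once per epoch instead.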
# define callbacks
callbacks_list = [chk]
# Build model
model = DANN(encoder=encoder(), task=task(), discriminator=discriminator(),
             Xt=Xt, lambda_=0.1, loss="bce", metrics=["acc"], random_state=0)
# start training
model_log = model.fit(Xs, ys, epochs=2, callbacks=callbacks_list, verbose=1)
# Load saved weights
model.load_weights('Model.hdf5')

If you want to load the weights in a new model, you still have to call the fit function (with epochs=0) to instantiate the variables:

new_model = DANN(encoder=encoder(), task=task(), discriminator=discriminator(),
                 Xt=Xt, lambda_=0.1, metrics=["acc"], random_state=0)
new_model.fit(Xs, ys, epochs=0)
new_model.load_weights('Model.hdf5')

PS: I see that, in your code, you use the accuracy metric. I guess you are solving a classification problem? Be careful with the default parameters of DANN: the default loss is "mse", but I think you need the binary cross-entropy ("bce"), or the categorical cross-entropy for multiclass. You can also change the default optimizer in the DANN arguments.
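For reference, a minimal sketch of what the multiclass configuration could look like, assuming a one-hot encoded target with 3 classes (the class count, the softmax task head, and the Adam learning rate are illustrative, not taken from this issue):

# Hypothetical 3-class task head (the original example uses a single sigmoid unit)
def multiclass_task():
    mod = tf.keras.Sequential()
    mod.add(tf.keras.layers.Dense(3, activation="softmax"))
    return mod

# Illustrative one-hot labels for 3 classes
ys_multi = tf.keras.utils.to_categorical(np.random.choice(3, 100), num_classes=3)

multi_model = DANN(encoder=encoder(), task=multiclass_task(), discriminator=discriminator(),
                   Xt=Xt, lambda_=0.1,
                   loss="categorical_crossentropy",           # instead of the default "mse"
                   optimizer=tf.keras.optimizers.Adam(1e-3),  # custom optimizer
                   metrics=["acc"], random_state=0)
multi_model.fit(Xs, ys_multi, epochs=2, verbose=1)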
Yes, that worked, I can now save the best run. Thank you @antoinedemathelin. As for your remark, I am using the cross-entropy since I am dealing with multiclass classification; it was clear in the documentation that it should be written within the DANN class. Thank you
Hello,
Thank you for your work.
I am doing domain adaptation with DANN. I would like to save the best model using a ModelCheckpoint callback based on the loss value of the task network:
During the training, I keep receiving this warning:
This is the training code: