diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py index 1868e8c8e..b76eacf14 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py @@ -311,17 +311,23 @@ class FieldWiseEmbedding(BasicEmbedding): ```python nslots = 3 @tf.function - def map_slot_fn(feature_id): + def feature_to_slot(feature_id): field_id = tf.math.mod(feature_id, nslots) return field_id ids = tf.constant([[23, 12, 0], [9, 13, 10]], dtype=tf.int64) - embedding = de.layers.FieldWiseEmbedding(1, nslots, map_slot_fn) + embedding = de.layers.FieldWiseEmbedding(2, + nslots, + slot_map_fn=feature_to_slot, + initializer=tf.keras.initializer.Zeros()) + + out = embedding(ids) + # [[[0., 0.], [0., 0.], [0., 1.]] + # [[0., 0.], [0., 0.], [0., 1.]]] prepared_keys = tf.range(0, 100, dtype=tf.int64) prepared_values = tf.ones((100, 2), dtype=tf.float32) embedding.params.upsert(prepared_keys, prepared_values) - out = embedding(ids) # [[2., 2.], [0., 0.], [1., 1.]] # [[1., 1.], [2., 2.], [0., 0.]]