
tf.tpu.experimental.embedding.TPUEmbeddingForServing

The TPUEmbedding mid level API running on CPU for serving.

Note: This class is intended for embedding tables that are trained on TPU and served on CPU. It should therefore only be initialized under a non-TPU strategy; otherwise an error will be raised.

You can first train your model using the TPUEmbedding class and save the checkpoint. Then use this class to restore the checkpoint to do serving.

First train a model and save the checkpoint.

model = model_fn(...)
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))

# Your custom training code.

checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.save(...)

Then restore the checkpoint and do serving.

# Restore the model on CPU.
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbeddingForServing(
    feature_config=feature_config,
    optimizer=tf.tpu.experimental.embedding.SGD(0.1))

checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)

result = embedding(...)
table = embedding.embedding_table

Note: This class can also be used to do embedding training on CPU. However, this requires converting between Keras optimizers and the embedding optimizers so that the slot variables stay consistent between them.
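As an illustration of the slot-variable issue mentioned in the note above, here is a minimal pure-Python sketch (hypothetical code, not part of TensorFlow): plain SGD updates an embedding row with no extra state, while an optimizer like Adagrad must also carry a per-parameter slot variable (the accumulator) that would need to stay consistent between a Keras optimizer and the embedding optimizer.

```python
# Hypothetical sketch, not TensorFlow code: contrasts an optimizer with no
# slot variables (SGD) against one that keeps per-parameter slot state
# (Adagrad's accumulator).

def sgd_update(row, grad, lr=0.1):
    # Plain SGD: no slot variables, only the embedding row itself changes.
    return [w - lr * g for w, g in zip(row, grad)]

def adagrad_update(row, grad, accumulator, lr=0.1, eps=1e-7):
    # Adagrad: keeps one slot variable (the accumulator) per parameter,
    # which must be checkpointed and restored alongside the table.
    new_acc = [a + g * g for a, g in zip(accumulator, grad)]
    new_row = [w - lr * g / ((a + eps) ** 0.5)
               for w, g, a in zip(row, grad, new_acc)]
    return new_row, new_acc

row = [1.0, 2.0]
grad = [0.5, -0.5]
print(sgd_update(row, grad))                  # no extra state returned
print(adagrad_update(row, grad, [0.0, 0.0]))  # returns row AND slot state
```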
Args
feature_config: A nested structure of tf.tpu.experimental.embedding.FeatureConfig configs.
optimizer: An instance of one of tf.tpu.experimental.embedding.SGD, tf.tpu.experimental.embedding.Adagrad or tf.tpu.experimental.embedding.Adam. When not created under TPUStrategy, this may be set to None to avoid creating the optimizer slot variables; this is useful for reducing memory consumption when exporting the model for serving, where slot variables are not needed.
Raises
RuntimeError: If created under TPUStrategy.
Attributes
embedding_tables: Returns a dict of embedding tables, keyed by TableConfig.

Methods

build

Creates the variables and slot variables for the TPU embeddings.

embedding_lookup

Apply standard lookup ops on CPU.

Args
features A nested structure of tf.Tensors, tf.SparseTensors or tf.RaggedTensors, with the same structure as feature_config. Inputs will be downcast to tf.int32. Only one type out of tf.SparseTensor or tf.RaggedTensor is supported per call.
weights If not None, a nested structure of tf.Tensors, tf.SparseTensors or tf.RaggedTensors, matching the above, except that the tensors should be of float type (and they will be downcast to tf.float32). For tf.SparseTensors we assume the indices are the same for the parallel entries from features and similarly for tf.RaggedTensors we assume the row_splits are the same.
Returns
A nested structure of Tensors with the same structure as input features.
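To make the weighted-lookup semantics concrete, here is a small pure-Python sketch (a hypothetical helper, not TensorFlow code) of a per-example "mean" combine: gather one table row per feature id, then take the (optionally weighted) average per example.

```python
# Hypothetical sketch, not TensorFlow code: the per-example weighted-mean
# combine that an embedding lookup with a "mean" combiner computes.

def lookup_mean(table, ids_per_example, weights_per_example=None):
    dim = len(table[0])
    results = []
    for i, ids in enumerate(ids_per_example):
        ws = weights_per_example[i] if weights_per_example else [1.0] * len(ids)
        total = [0.0] * dim
        for idx, w in zip(ids, ws):
            for d in range(dim):
                # Gather row `idx` from the table, scaled by its weight.
                total[d] += w * table[idx][d]
        denom = sum(ws)
        results.append([t / denom for t in total])
    return results

table = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
# Example 0 looks up ids 0 and 1; example 1 looks up id 2 only.
print(lookup_mean(table, [[0, 1], [2]]))
```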

__call__

Calls the mid-level API to do an embedding lookup.
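Conceptually, calling the instance delegates to embedding_lookup; a hypothetical sketch of that delegation (not the actual TensorFlow source):

```python
# Hypothetical sketch (not the TensorFlow source): __call__ on the mid-level
# API simply forwards its arguments to embedding_lookup.
class MidLevelSketch:
    def embedding_lookup(self, features, weights=None):
        # Stand-in for the real lookup; echoes its inputs for illustration.
        return ("lookup", features, weights)

    def __call__(self, features, weights=None):
        return self.embedding_lookup(features, weights)

print(MidLevelSketch()({"watched": [1, 2]}))
```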

© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/tpu/experimental/embedding/TPUEmbeddingForServing