Keras Applications
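```python
import tensorflow as tf

num_classes = 10  # placeholder: set to your dataset's number of classes

# include_top=False drops the 1000-class ImageNet head so a new one can be attached;
# pooling="avg" returns one 2048-d feature vector per image.
base = tf.keras.applications.ResNet50(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3),
    pooling="avg",
)
base.trainable = False  # freeze the pretrained backbone

# New classification head for num_classes.
inputs = tf.keras.Input(shape=(224, 224, 3))
x = base(inputs, training=False)  # keep BatchNorm layers in inference mode
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)
```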
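To confirm that freezing works before committing to a long training run, the sketch below rebuilds the same head with weights=None and a 64x64 input. Both are assumptions made only so the check runs offline and quickly; they are not settings to train with. It checks that only the new Dense layer's kernel and bias are trainable.

```python
import numpy as np
import tensorflow as tf

num_classes = 5  # hypothetical class count

# weights=None skips the ImageNet download; the small input keeps the check fast.
base = tf.keras.applications.ResNet50(
    include_top=False, weights=None, input_shape=(64, 64, 3), pooling="avg"
)
base.trainable = False

inputs = tf.keras.Input(shape=(64, 64, 3))
x = base(inputs, training=False)
outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

# Only the new Dense head (kernel + bias) should be trainable.
print(len(model.trainable_weights))  # 2
probs = model.predict(np.random.rand(2, 64, 64, 3).astype("float32"), verbose=0)
print(probs.shape)  # (2, 5)
```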
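Use the preprocess_input function from the same module as the architecture, because input conventions differ. tf.keras.applications.resnet50.preprocess_input converts RGB to BGR and subtracts the ImageNet per-channel means without scaling. tf.keras.applications.efficientnet.preprocess_input is a pass-through, because EfficientNet models rescale their inputs internally.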
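A quick way to see the difference is to run both functions on the same batch of raw pixels. The random batch below is an illustrative assumption. The arrays are copied first because, for NumPy float input, the ResNet function may modify the array in place.

```python
import numpy as np
import tensorflow as tf

# Dummy batch of two 224x224 RGB images with raw pixel values in [0, 255].
images = np.random.uniform(0, 255, size=(2, 224, 224, 3)).astype("float32")

# ResNet50 ("caffe" style): RGB -> BGR, then subtract ImageNet channel means.
resnet_in = tf.keras.applications.resnet50.preprocess_input(images.copy())

# EfficientNet: returns the input unchanged; the model rescales internally.
effnet_in = tf.keras.applications.efficientnet.preprocess_input(images.copy())
```

After the call, resnet_in's first channel holds the original blue channel minus 103.939, while effnet_in still matches images.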
tf.data from file paths
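```python
import tensorflow as tf

# Assumes `paths` (list of JPEG file paths) and `labels` (matching integer
# class ids) are already defined.
def load_decode(path, label):
    img = tf.io.read_file(path)                  # raw file bytes
    # JPEG only; tf.io.decode_image(..., expand_animations=False) also reads PNG.
    img = tf.image.decode_jpeg(img, channels=3)  # uint8, shape (H, W, 3)
    img = tf.image.resize(img, [224, 224])       # float32, shape (224, 224, 3)
    return img, label

paths_ds = tf.data.Dataset.from_tensor_slices((paths, labels))
ds = paths_ds.map(load_decode, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(32).prefetch(tf.data.AUTOTUNE)
```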
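The pipeline can be exercised end to end without a real dataset. In the sketch below, a few synthetic JPEGs are written to a temporary directory as stand-ins for real files. It also adds shuffling and applies the ResNet50 preprocess_input inside map, so every batch already matches the backbone's expected input. The synthetic images, batch size of 2, and the choice of ResNet50 preprocessing are all assumptions for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Write four synthetic JPEGs so the pipeline runs without a real dataset.
tmp_dir = tempfile.mkdtemp()
paths, labels = [], []
for i in range(4):
    pixels = np.random.randint(0, 256, size=(64, 80, 3), dtype=np.uint8)
    path = os.path.join(tmp_dir, f"img_{i}.jpg")
    tf.io.write_file(path, tf.io.encode_jpeg(pixels))
    paths.append(path)
    labels.append(i % 2)  # two dummy classes

def load_decode(path, label):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, [224, 224])
    # Architecture-matched preprocessing, applied inside the pipeline.
    img = tf.keras.applications.resnet50.preprocess_input(img)
    return img, label

ds = (
    tf.data.Dataset.from_tensor_slices((paths, labels))
    .shuffle(len(paths))  # reshuffled every epoch by default
    .map(load_decode, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)
)

images, batch_labels = next(iter(ds))
```

Shuffling before map keeps the shuffle buffer cheap, since it holds file paths rather than decoded images.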
KerasCV (optional)
Installing KerasCV (pip install keras-cv) adds detection models, augmentation layers, and COCO-style metrics designed for Keras 3 multi-backend workflows. Check its version compatibility matrix against your installed TensorFlow and Keras versions before upgrading.
Takeaways
- Keep training and serving preprocessing identical whenever possible, either as Keras preprocessing layers inside the model or in the SavedModel serving signature.
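- Mixed precision: call tf.keras.mixed_precision.set_global_policy("mixed_float16") on supported GPUs (compute capability 7.0 or higher for a real speedup), and keep the output layer in float32.
- Compare with PyTorch and JAX for team skill fit and deployment targets (TF Lite, Edge TPU).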
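The first two takeaways fit into one small sketch. Preprocessing is a layer inside the model, so a served model accepts the same raw input used in training. Mixed precision is enabled globally, and the final layer is pinned to float32 for numerically stable probabilities. A toy Dense model on 8-feature inputs stands in for a CNN, and Rescaling stands in for whatever preprocessing your data needs; both are assumptions to keep the example light. On CPU it runs but will not be faster.

```python
import numpy as np
import tensorflow as tf

# float16 compute with float32 variables for every layer created after this call.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

num_classes = 3  # hypothetical class count

inputs = tf.keras.Input(shape=(8,))  # toy features standing in for images
# Preprocessing lives inside the model, so training and serving share it.
x = tf.keras.layers.Rescaling(1.0 / 255)(inputs)
x = tf.keras.layers.Dense(16, activation="relu")(x)
# Keep the softmax output in float32 for numerically stable probabilities.
outputs = tf.keras.layers.Dense(num_classes, activation="softmax", dtype="float32")(x)
model = tf.keras.Model(inputs, outputs)

probs = model(np.random.uniform(0, 255, size=(4, 8)).astype("float32"))

# Restore the default so later code in the same process is unaffected.
tf.keras.mixed_precision.set_global_policy("float32")
```

Calling tf.keras.mixed_precision.Policy("mixed_float16") on its own only creates a policy object; set_global_policy is what makes new layers use it.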