Understanding Overfitting in Neural Networks (TensorFlow CNN)
Overfitting is a fundamental challenge when developing neural networks. A model that performs extremely well on the training dataset may fail to generalize to unseen data, leading to poor real‑world performance. This article presents a structured investigation of overfitting using the Fashion‑MNIST dataset and evaluates several mitigation strategies, including Dropout, L2 regularisation, and Early Stopping.
Fashion‑MNIST dataset
- 60,000 training images
- 10,000 test images
- 28 × 28 grayscale format
- 10 output classes
A significantly smaller subset of the training data is intentionally used to make overfitting behaviour more visible.
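The later code snippets refer to x_train_small and y_train_small. A minimal data-preparation sketch is shown below; the article does not fix the exact subset size, so the 5,000-example slice here is an assumption.

import numpy as np
from tensorflow import keras

# Load Fashion-MNIST, scale pixels to [0, 1], and add a channel dimension for Conv2D.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype("float32")[..., np.newaxis] / 255.0
x_test = x_test.astype("float32")[..., np.newaxis] / 255.0

# Deliberately small training subset to make overfitting easier to observe.
subset_size = 5000  # assumed value; the article only says "significantly smaller"
x_train_small = x_train[:subset_size]
y_train_small = y_train[:subset_size]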
CNN Architecture
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers

def create_cnn_model(l2_lambda=0.0, dropout_rate=0.0):
    # Two convolutional blocks followed by a small dense classifier.
    model = keras.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu',
                      input_shape=(28, 28, 1),
                      kernel_regularizer=regularizers.l2(l2_lambda)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu',
                      kernel_regularizer=regularizers.l2(l2_lambda)),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation='relu',
                     kernel_regularizer=regularizers.l2(l2_lambda)),
        layers.Dropout(dropout_rate),  # no-op when dropout_rate=0.0
        layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"]
    )
    return model
All experiments share this architecture, with optional L2 regularisation and Dropout.
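For reference, any configuration can be instantiated by toggling the two arguments; calling summary() on a fresh instance confirms the layer structure and parameter count (the hyperparameter values below are purely illustrative).

# Illustrative instantiation with both regularisers enabled.
demo_model = create_cnn_model(l2_lambda=0.001, dropout_rate=0.5)
demo_model.summary()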
Utility for Plotting Training History
import matplotlib.pyplot as plt

def plot_history(history, title_prefix=""):
    # Plot training vs. validation loss and accuracy side by side.
    hist = history.history
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(hist["loss"], label="Train Loss")
    plt.plot(hist["val_loss"], label="Val Loss")
    plt.title(f"{title_prefix} Loss")
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(hist["accuracy"], label="Train Accuracy")
    plt.plot(hist["val_accuracy"], label="Val Accuracy")
    plt.title(f"{title_prefix} Accuracy")
    plt.legend()

    plt.tight_layout()
    plt.show()
Baseline Model (No Regularisation)
baseline_model = create_cnn_model(l2_lambda=0.0, dropout_rate=0.0)

history_baseline = baseline_model.fit(
    x_train_small, y_train_small,
    validation_split=0.2,
    epochs=20
)

plot_history(history_baseline, title_prefix="Baseline (no regularisation)")
Observations
- Training accuracy continues to increase steadily.
- Validation accuracy peaks early and then declines.
- Training loss decreases, while validation loss increases.
These patterns indicate clear overfitting.
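To make the gap concrete, the trained baseline can be evaluated on the held-out test set and compared with accuracy on the training subset; the snippet below assumes x_test and y_test have been preprocessed the same way as the training data.

# Quantify the generalisation gap of the baseline model.
train_loss, train_acc = baseline_model.evaluate(x_train_small, y_train_small, verbose=0)
test_loss, test_acc = baseline_model.evaluate(x_test, y_test, verbose=0)
print(f"Train accuracy: {train_acc:.3f}")
print(f"Test accuracy:  {test_acc:.3f}")
print(f"Generalisation gap: {train_acc - test_acc:.3f}")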
Dropout (Rate = 0.5)
dropout_model = create_cnn_model(dropout_rate=0.5)

history_dropout = dropout_model.fit(
    x_train_small, y_train_small,
    validation_split=0.2,
    epochs=20
)

plot_history(history_dropout, title_prefix="Dropout (0.5)")
Observations
- Training accuracy rises more slowly (expected due to Dropout).
- Validation accuracy tracks the training curve more closely.
- Divergence between training and validation loss is significantly reduced.
Dropout is highly effective in this experiment, producing noticeably improved generalisation.
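One simple way to compare the divergence numerically is the final-epoch gap between validation and training loss in the two history objects (a rough indicator, not a formal metric):

# A larger gap between validation and training loss indicates stronger overfitting.
def final_loss_gap(history):
    h = history.history
    return h["val_loss"][-1] - h["loss"][-1]

print("Baseline loss gap:", round(final_loss_gap(history_baseline), 3))
print("Dropout loss gap: ", round(final_loss_gap(history_dropout), 3))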
L2 Regularisation (λ = 0.001)
l2_model = create_cnn_model(l2_lambda=0.001)

history_l2 = l2_model.fit(
    x_train_small, y_train_small,
    validation_split=0.2,
    epochs=20
)

plot_history(history_l2, title_prefix="L2 Regularisation")
Observations
- Training loss is noticeably higher due to weight penalisation.
- Validation loss trends are more stable compared to the baseline.
- Validation accuracy improves moderately.
L2 regularisation smooths learning dynamics and alleviates overfitting, though its impact is milder than Dropout in this setup.
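The mechanism behind this can be inspected directly: the L2 penalty pushes kernel weights towards smaller magnitudes. A small sketch comparing total kernel norms of the baseline and the L2-regularised model (the helper name is illustrative):

import numpy as np

def total_kernel_norm(model):
    # Sum of the Frobenius norms of all kernel matrices (biases excluded).
    return sum(np.linalg.norm(w)
               for layer in model.layers
               for w in layer.get_weights()[:1])

print("Baseline kernel norm:", round(total_kernel_norm(baseline_model), 2))
print("L2 model kernel norm:", round(total_kernel_norm(l2_model), 2))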
Early Stopping
earlystop_model = create_cnn_model()

early_stop = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True
)

history_early = earlystop_model.fit(
    x_train_small, y_train_small,
    validation_split=0.2,
    epochs=20,
    callbacks=[early_stop]
)

plot_history(history_early, title_prefix="Early Stopping")
Observations
- Training terminates after validation loss stops improving.
- Avoids the late‑epoch overfitting seen in the baseline.
- Produces one of the cleanest validation curves among all models.
Early stopping is a simple and effective generalisation technique.
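Because restore_best_weights=True, the returned model keeps the weights from its best validation-loss epoch. How many epochs were actually run can be read from the history object:

# With patience=3, training typically stops well before the 20-epoch limit.
epochs_run = len(history_early.history["loss"])
best_val_loss = min(history_early.history["val_loss"])
print("Epochs actually trained:", epochs_run)
print("Best validation loss:", round(best_val_loss, 4))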
Model Conversion (Optional)
converter = tf.lite.TFLiteConverter.from_keras_model(baseline_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # enables dynamic-range quantisation
tflite_model = converter.convert()

print("Quantised model size (bytes):", len(tflite_model))
This step demonstrates model size reduction for deployment purposes; it is not a regularisation strategy.
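For completeness, the converted model can be written to disk for deployment; the file name below is only an example.

# Save the converted FlatBuffer so it can be shipped with an app or device image.
with open("fashion_mnist_baseline.tflite", "wb") as f:
    f.write(tflite_model)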
Summary of Results
- Baseline: Clear overfitting.
- Dropout: Largest improvement in validation behaviour.
- L2 Regularisation: Helps stabilise training dynamics.
- Early Stopping: Prevents late‑epoch divergence and improves generalisation.
- Combination (Dropout + Early Stopping): Yields the most robust performance on the reduced Fashion‑MNIST dataset; a minimal sketch of this run follows below.
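The combined run is not shown in full above; a minimal sketch, reusing the same hyperparameters as the individual experiments, would look like this:

# Combined experiment: Dropout (0.5) plus Early Stopping, same setup as above.
combined_model = create_cnn_model(dropout_rate=0.5)
history_combined = combined_model.fit(
    x_train_small, y_train_small,
    validation_split=0.2,
    epochs=20,
    callbacks=[keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=3, restore_best_weights=True)]
)
plot_history(history_combined, title_prefix="Dropout + Early Stopping")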