A step-by-step guide to fine-tuning MedGemma for breast tumor classification
Introduction
Artificial intelligence (AI) is revolutionizing healthcare, but how do you take a powerful, general-purpose AI model and teach it the specialized skills of a pathologist? This guide walks through the first step: fine-tuning MedGemma, a medically tuned variant of Gemma 3, to classify breast cancer histopathology images.
Goal
Classify microscope images of breast tissue into one of eight categories (four benign, four malignant). We'll use google/medgemma-4b-it from MedGemma, Google's open family of medical models, which pairs the MedSigLIP vision encoder with a Gemma 3 language backbone.
Dataset
We use the Breast Cancer Histopathological Image Classification (BreakHis) dataset, a public collection of thousands of microscope images from 82 patients at four magnifications (40×, 100×, 200×, 400×). The dataset is available for non‑commercial research; see the original paper:
F. A. Spanhol, L. S. Oliveira, C. Petitjean, and L. Heutte, "A Dataset for Breast Cancer Histopathological Image Classification," IEEE Transactions on Biomedical Engineering, 63(7), 2016.
For this demonstration we focus on Fold 1 and 100× magnification to keep training time manageable.
Hardware
Fine-tuning a 4-billion-parameter model requires a capable GPU. In the notebook we used an NVIDIA A100 (40 GB VRAM) on Vertex AI Workbench; its Tensor Cores natively accelerate mixed-precision formats such as bfloat16.
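Before training, it is worth a quick, illustrative sanity check (not part of the original notebook) that the runtime actually sees the GPU and that it supports bfloat16, which matters for the next section:
import torch

# Confirm a CUDA device is visible and report its name and memory
assert torch.cuda.is_available(), "No CUDA device found - check the Workbench machine type"
print(torch.cuda.get_device_name(0))
print(f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB VRAM")

# A100-class (Ampere) GPUs support bfloat16 natively
print("bfloat16 supported:", torch.cuda.is_bf16_supported())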
Float16 vs. Bfloat16
The first attempt used float16 (FP16) to save memory, but training collapsed into NaNs due to numerical overflow (FP16's maximum representable value is about 65,504). Switching to bfloat16 (BF16), which keeps float32's dynamic range at the cost of some mantissa precision, prevents the overflow and stabilizes training.
# The simple, stable solution
MODEL_ID = "google/medgemma-4b-it"

model_kwargs = dict(
    torch_dtype=torch.bfloat16,    # bfloat16: wide numerical range, no FP16-style overflow
    device_map="auto",             # place the model on the available GPU automatically
    attn_implementation="sdpa",    # PyTorch scaled-dot-product attention
)
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
processor = AutoProcessor.from_pretrained(MODEL_ID)  # matching multimodal processor, used by the trainer later
Lesson: Prefer bfloat16 for fine‑tuning large models to avoid NaN‑related issues.
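To see the overflow concretely, here is a tiny illustrative example (not from the notebook): a value just above FP16's ceiling becomes inf in float16 but stays finite in bfloat16.
import torch

x = torch.tensor(70000.0)      # larger than FP16's maximum of ~65,504
print(x.to(torch.float16))     # inf - the overflow that eventually surfaces as NaN losses
print(x.to(torch.bfloat16))    # finite (rounded to ~70144) - coarser precision, same range as float32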
Step‑by‑Step Notebook Overview
1. Install Required Packages
!pip install --upgrade --quiet transformers datasets evaluate peft trl scikit-learn accelerate
import os
import re
import torch
import gc
from datasets import load_dataset, ClassLabel
from peft import LoraConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import SFTTrainer, SFTConfig
import evaluate
2. Authenticate with Hugging Face
Security note: Never hard‑code secrets (API keys, tokens) in notebooks. Use Google Cloud Secret Manager in production. For quick experiments, the interactive login widget can be used.
from huggingface_hub import notebook_login
notebook_login()
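For the production path mentioned above, here is a hedged sketch of reading the token from Google Cloud Secret Manager instead; the project and secret names are placeholders you would replace with your own:
from google.cloud import secretmanager
from huggingface_hub import login

# Placeholder resource name - substitute your own project and secret
SECRET_NAME = "projects/your-gcp-project/secrets/hf-token/versions/latest"

client = secretmanager.SecretManagerServiceClient()
response = client.access_secret_version(request={"name": SECRET_NAME})
hf_token = response.payload.data.decode("UTF-8")

login(token=hf_token)  # authenticate without the token ever appearing in the notebook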
3. Download the BreakHis Dataset
!pip install -q kagglehub
import kagglehub
import pandas as pd
from PIL import Image
from datasets import Dataset, Image as HFImage, Features, ClassLabel
# Download dataset metadata
path = kagglehub.dataset_download("ambarish/breakhis")
print("Path to dataset files:", path)
folds = pd.read_csv(f"{path}/Folds.csv")
# Filter for 100X magnification from Fold 1
folds_100x = folds[(folds["mag"] == 100) & (folds["fold"] == 1)]
# Train / test splits
folds_100x_test = folds_100x[folds_100x["grp"] == "test"]
folds_100x_train = folds_100x[folds_100x["grp"] == "train"]
# Base path for images
BASE_PATH = "/home/jupyter/.cache/kagglehub/datasets/ambarish/breakhis/v"
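Before moving on, a quick illustrative check of how many images the filtered 100X, fold-1 subset actually contains:
# Sizes of the filtered splits
print("train images:", len(folds_100x_train))
print("test images: ", len(folds_100x_test))
print(folds_100x["grp"].value_counts())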
4. Prepare the Dataset for Hugging Face
(Implementation details omitted for brevity; follow the notebook to map image paths to HFImage objects, create ClassLabel taxonomy, and split into Dataset objects.)
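For orientation, below is a rough sketch of one way this step could look. It assumes the Folds.csv rows carry a relative image path in a filename column and that the eight BreakHis subtype names appear in those paths; both are assumptions to verify against the downloaded files, and the notebook remains the authoritative version.
# The eight BreakHis subtypes (four benign, four malignant)
CLASS_NAMES = [
    "adenosis", "fibroadenoma", "phyllodes_tumor", "tubular_adenoma",
    "ductal_carcinoma", "lobular_carcinoma", "mucinous_carcinoma", "papillary_carcinoma",
]

def extract_label(rel_path: str) -> str:
    # The subtype is encoded as a directory name in each image path
    for name in CLASS_NAMES:
        if name in rel_path:
            return name
    raise ValueError(f"No known subtype in path: {rel_path}")

def to_hf_dataset(df: pd.DataFrame) -> Dataset:
    records = pd.DataFrame({
        "image": df["filename"].map(lambda p: os.path.join(path, p)),                 # absolute image path
        "label": df["filename"].map(lambda p: CLASS_NAMES.index(extract_label(p))),   # integer class id
    })
    features = Features({"image": HFImage(), "label": ClassLabel(names=CLASS_NAMES)})
    return Dataset.from_pandas(records, features=features, preserve_index=False)

train_dataset = to_hf_dataset(folds_100x_train)
eval_dataset = to_hf_dataset(folds_100x_test)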
5. Configure LoRA Fine‑Tuning
lora_config = LoraConfig(
    r=16,                                  # rank of the low-rank update matrices
    lora_alpha=32,                         # scaling factor applied to the LoRA updates
    target_modules=["q_proj", "v_proj"],   # attach adapters to the attention query/value projections
    lora_dropout=0.05,                     # dropout on the adapter layers for regularization
    bias="none",                           # keep the original bias terms frozen
    task_type="CAUSAL_LM",                 # fine-tune as a causal language-modeling task
)
6. Initialize the Trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=processor,                  # newer TRL releases name this argument processing_class
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    args=SFTConfig(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,    # effective batch size of 16 (4 x 4)
        learning_rate=5e-5,
        num_train_epochs=3,
        fp16=False,                       # using bfloat16 instead
        bf16=True,
        logging_steps=10,
        evaluation_strategy="steps",      # eval_strategy in newer transformers releases
        eval_steps=50,
        save_steps=100,
        output_dir="./medgemma_finetuned",
    ),
)
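Because peft_config was supplied, the trainer wraps the model with the LoRA adapters. As an optional sanity check, you can confirm how little of the 4-billion-parameter model is actually trainable:
# Only the LoRA adapter weights should be trainable (well under 1% of all parameters)
trainer.model.print_trainable_parameters()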
7. Train and Evaluate
trainer.train()
trainer.evaluate()
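Beyond the aggregate metrics, a quick qualitative check is to run a single held-out image through the fine-tuned model. The prompt wording, generation length, and chat structure below are illustrative assumptions; they should mirror whatever prompt template the notebook used during training.
# Illustrative single-image inference (prompt text is an assumption, not the notebook's exact template)
example = eval_dataset[0]
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": example["image"]},
            {"type": "text", "text": "What tumor subtype does this histopathology image show?"},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)

with torch.inference_mode():
    generated = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# Decode only the newly generated tokens, not the prompt
answer = processor.decode(generated[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(answer)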
Next Steps
After prototyping, the workflow can be migrated to a scalable, production‑ready environment using Cloud Run jobs (see the upcoming post for details).
Disclaimer: This guide is for informational and educational purposes only and is not a substitute for professional medical advice, diagnosis, or treatment.