分步指南:微调 MedGemma 进行乳腺肿瘤分类
Source: Dev.to
介绍
人工智能(AI)正在彻底改变医疗保健,但如何将一个强大的通用 AI 模型教会病理学家的专业技能呢?本指南将演示第一步:对 Gemma 3 变体 MedGemma 进行微调,以对乳腺癌组织学图像进行分类。
目标
将显微镜下的乳腺组织图像分类为八个类别之一(四个良性,四个恶性)。我们将使用 MedGemma 模型——Google 开源的医学模型系列——以及它的视觉组件 MedSigLIP 和语言组件 google/medgemma-4b-it。
数据集
我们使用 Breast Cancer Histopathological Image Classification (BreakHis) 数据集,这是一个公开的包含来自 82 位患者、四种放大倍率(40×、100×、200×、400×)的数千张显微镜图像的集合。该数据集可用于非商业研究;详见原始论文:
F. A. Spanhol, L. S. Oliveira, C. Petitjean, and L. Heudel, A dataset for breast cancer histopathological image classification.
在本演示中,我们聚焦于 Fold 1 和 100× 放大倍率,以保持训练时间可控。
硬件
对一个 40 亿参数的模型进行微调需要强大的 GPU。在 notebook 中我们使用了 NVIDIA A100(40 GB VRAM),运行在 Vertex AI Workbench 上,它提供了加速现代数据格式的 Tensor Cores。
Float16 与 Bfloat16
首次尝试使用 float16(FP16)以节省显存,但训练过程中因数值溢出(FP16 的最大可表示值约为 65 504)而出现 NaN。切换到 bfloat16 (BF16)——保留了 32 位浮点数的范围,只牺牲了一些精度——可以防止溢出并使训练稳定。
# The simple, stable solution
model_kwargs = dict(
torch_dtype=torch.bfloat16, # Use bfloat16 for its wide numerical range
device_map="auto",
attn_implementation="sdpa",
)
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
经验教训: 对大模型进行微调时,优先使用 bfloat16 以避免 NaN 相关的问题。
步骤笔记本概览
1. 安装所需软件包
!pip install --upgrade --quiet transformers datasets evaluate peft trl scikit-learn
import os
import re
import torch
import gc
from datasets import load_dataset, ClassLabel
from peft import LoraConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import SFTTrainer, SFTConfig
import evaluate
2. 使用 Hugging Face 进行身份验证
安全提示: 切勿在 notebook 中硬编码密钥(API key、token)。在生产环境中请使用 Google Cloud Secret Manager。进行快速实验时,可使用交互式登录小部件。
from huggingface_hub import notebook_login
notebook_login()
3. 下载 BreakHis 数据集
!pip install -q kagglehub
import kagglehub
import pandas as pd
from PIL import Image
from datasets import Dataset, Image as HFImage, Features, ClassLabel
# Download dataset metadata
path = kagglehub.dataset_download("ambarish/breakhis")
print("Path to dataset files:", path)
folds = pd.read_csv(f"{path}/Folds.csv")
# Filter for 100X magnification from Fold 1
folds_100x = folds[(folds["mag"] == 100) & (folds["fold"] == 1)]
# Train / test splits
folds_100x_test = folds_100x[folds_100x["grp"] == "test"]
folds_100x_train = folds_100x[folds_100x["grp"] == "train"]
# Base path for images
BASE_PATH = "/home/jupyter/.cache/kagglehub/datasets/ambarish/breakhis/v"
4. 为 Hugging Face 准备数据集
(实现细节省略;请参考 notebook,将图像路径映射为 HFImage 对象,创建 ClassLabel 分类体系,并划分为 Dataset 对象。)
5. 配置 LoRA 微调
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
6. 初始化 Trainer
trainer = SFTTrainer(
model=model,
tokenizer=processor,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=lora_config,
args=SFTConfig(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=5e-5,
num_train_epochs=3,
fp16=False, # Using bfloat16 instead
bf16=True,
logging_steps=10,
evaluation_strategy="steps",
eval_steps=50,
save_steps=100,
output_dir="./medgemma_finetuned",
),
)
7. 训练与评估
trainer.train()
trainer.evaluate()
后续步骤
原型完成后,可将工作流迁移至可扩展的生产环境,例如使用 Cloud Run jobs(详见后续文章)。
免责声明:本指南仅供信息和教育用途,不能替代专业的医学建议、诊断或治疗。