分步指南:微调 MedGemma 进行乳腺肿瘤分类

发布: (2025年12月3日 GMT+8 03:45)
4 min read
原文: Dev.to

Source: Dev.to

介绍

人工智能(AI)正在彻底改变医疗保健,但如何将一个强大的通用 AI 模型教会病理学家的专业技能呢?本指南将演示第一步:对 Gemma 3 变体 MedGemma 进行微调,以对乳腺癌组织学图像进行分类。

目标

将显微镜下的乳腺组织图像分类为八个类别之一(四个良性,四个恶性)。我们将使用 MedGemma 模型——Google 开源的医学模型系列——以及它的视觉组件 MedSigLIP 和语言组件 google/medgemma-4b-it

数据集

我们使用 Breast Cancer Histopathological Image Classification (BreakHis) 数据集,这是一个公开的包含来自 82 位患者、四种放大倍率(40×、100×、200×、400×)的数千张显微镜图像的集合。该数据集可用于非商业研究;详见原始论文:

F. A. Spanhol, L. S. Oliveira, C. Petitjean, and L. Heudel, A dataset for breast cancer histopathological image classification.

在本演示中,我们聚焦于 Fold 1100× 放大倍率,以保持训练时间可控。

硬件

对一个 40 亿参数的模型进行微调需要强大的 GPU。在 notebook 中我们使用了 NVIDIA A100(40 GB VRAM),运行在 Vertex AI Workbench 上,它提供了加速现代数据格式的 Tensor Cores。

Float16 与 Bfloat16

首次尝试使用 float16(FP16)以节省显存,但训练过程中因数值溢出(FP16 的最大可表示值约为 65 504)而出现 NaN。切换到 bfloat16 (BF16)——保留了 32 位浮点数的范围,只牺牲了一些精度——可以防止溢出并使训练稳定。

# The simple, stable solution
model_kwargs = dict(
    torch_dtype=torch.bfloat16,  # Use bfloat16 for its wide numerical range
    device_map="auto",
    attn_implementation="sdpa",
)

model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)

经验教训: 对大模型进行微调时,优先使用 bfloat16 以避免 NaN 相关的问题。

步骤笔记本概览

1. 安装所需软件包

!pip install --upgrade --quiet transformers datasets evaluate peft trl scikit-learn
import os
import re
import torch
import gc
from datasets import load_dataset, ClassLabel
from peft import LoraConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import SFTTrainer, SFTConfig
import evaluate

2. 使用 Hugging Face 进行身份验证

安全提示: 切勿在 notebook 中硬编码密钥(API key、token)。在生产环境中请使用 Google Cloud Secret Manager。进行快速实验时,可使用交互式登录小部件。

from huggingface_hub import notebook_login
notebook_login()

3. 下载 BreakHis 数据集

!pip install -q kagglehub
import kagglehub
import pandas as pd
from PIL import Image
from datasets import Dataset, Image as HFImage, Features, ClassLabel

# Download dataset metadata
path = kagglehub.dataset_download("ambarish/breakhis")
print("Path to dataset files:", path)

folds = pd.read_csv(f"{path}/Folds.csv")

# Filter for 100X magnification from Fold 1
folds_100x = folds[(folds["mag"] == 100) & (folds["fold"] == 1)]

# Train / test splits
folds_100x_test = folds_100x[folds_100x["grp"] == "test"]
folds_100x_train = folds_100x[folds_100x["grp"] == "train"]

# Base path for images
BASE_PATH = "/home/jupyter/.cache/kagglehub/datasets/ambarish/breakhis/v"

4. 为 Hugging Face 准备数据集

(实现细节省略;请参考 notebook,将图像路径映射为 HFImage 对象,创建 ClassLabel 分类体系,并划分为 Dataset 对象。)

5. 配置 LoRA 微调

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

6. 初始化 Trainer

trainer = SFTTrainer(
    model=model,
    tokenizer=processor,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    args=SFTConfig(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        num_train_epochs=3,
        fp16=False,          # Using bfloat16 instead
        bf16=True,
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=50,
        save_steps=100,
        output_dir="./medgemma_finetuned",
    ),
)

7. 训练与评估

trainer.train()
trainer.evaluate()

后续步骤

原型完成后,可将工作流迁移至可扩展的生产环境,例如使用 Cloud Run jobs(详见后续文章)。

免责声明:本指南仅供信息和教育用途,不能替代专业的医学建议、诊断或治疗。

Back to Blog

相关文章

阅读更多 »