유방 종양 분류를 위한 MedGemma 파인튜닝 단계별 가이드
Source: Dev.to
Introduction
인공지능(AI)은 의료 분야에 혁신을 일으키고 있지만, 강력한 범용 AI 모델을 어떻게 하면 병리학자의 전문 기술을 가르칠 수 있을까요? 이 가이드는 첫 번째 단계인 Gemma 3 변형 MedGemma를 미세조정하여 유방암 조직학 이미지 분류를 수행하는 방법을 안내합니다.
Goal
현미경으로 촬영한 유방 조직 이미지를 8가지 카테고리(양성 4개, 악성 4개) 중 하나로 분류합니다. 우리는 MedGemma 모델—구글의 오픈 의료 모델군—과 그 비전 컴포넌트 MedSigLIP, 언어 컴포넌트 google/medgemma-4b-it를 함께 사용할 것입니다.
Dataset
우리는 Breast Cancer Histopathological Image Classification (BreakHis) 데이터셋을 사용합니다. 이 데이터셋은 82명의 환자로부터 4가지 배율(40×, 100×, 200×, 400×)로 촬영된 수천 장의 현미경 이미지가 공개된 컬렉션입니다. 비상업적 연구 목적으로 사용할 수 있으며, 원본 논문은 다음과 같습니다:
F. A. Spanhol, L. S. Oliveira, C. Petitjean, and L. Heudel, A dataset for breast cancer histopathological image classification.
이번 시연에서는 Fold 1과 100× 배율에만 집중하여 학습 시간을 관리 가능한 수준으로 유지합니다.
Hardware
40 GB VRAM을 갖춘 NVIDIA A100을 Vertex AI Workbench에서 사용했습니다. 이 GPU는 최신 데이터 포맷을 가속화하는 Tensor Core를 제공합니다.
Float16 vs. Bfloat16
첫 번째 시도에서는 메모리 절약을 위해 float16(FP16)을 사용했지만, 수치 오버플로우( FP16의 최대 표현값 ≈ 65 504) 때문에 학습이 NaN으로 붕괴되었습니다. **bfloat16 (BF16)**으로 전환하면 32‑bit 부동소수점의 범위를 유지하면서 약간의 정밀도만 손실되므로 오버플로우를 방지하고 학습을 안정화할 수 있습니다.
# The simple, stable solution
model_kwargs = dict(
torch_dtype=torch.bfloat16, # Use bfloat16 for its wide numerical range
device_map="auto",
attn_implementation="sdpa",
)
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
Lesson: 대형 모델을 미세조정할 때는 NaN 관련 문제를 피하기 위해 bfloat16을 우선 사용하세요.
Step‑by‑Step Notebook Overview
1. Install Required Packages
!pip install --upgrade --quiet transformers datasets evaluate peft trl scikit-learn
import os
import re
import torch
import gc
from datasets import load_dataset, ClassLabel
from peft import LoraConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import SFTTrainer, SFTConfig
import evaluate
2. Authenticate with Hugging Face
Security note: Never hard‑code secrets (API keys, tokens) in notebooks. Use Google Cloud Secret Manager in production. For quick experiments, the interactive login widget can be used.
from huggingface_hub import notebook_login
notebook_login()
3. Download the BreakHis Dataset
!pip install -q kagglehub
import kagglehub
import pandas as pd
from PIL import Image
from datasets import Dataset, Image as HFImage, Features, ClassLabel
# Download dataset metadata
path = kagglehub.dataset_download("ambarish/breakhis")
print("Path to dataset files:", path)
folds = pd.read_csv(f"{path}/Folds.csv")
# Filter for 100X magnification from Fold 1
folds_100x = folds[(folds["mag"] == 100) & (folds["fold"] == 1)]
# Train / test splits
folds_100x_test = folds_100x[folds_100x["grp"] == "test"]
folds_100x_train = folds_100x[folds_100x["grp"] == "train"]
# Base path for images
BASE_PATH = "/home/jupyter/.cache/kagglehub/datasets/ambarish/breakhis/v"
4. Prepare the Dataset for Hugging Face
(Implementation details omitted for brevity; follow the notebook to map image paths to HFImage objects, create ClassLabel taxonomy, and split into Dataset objects.)
5. Configure LoRA Fine‑Tuning
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
6. Initialize the Trainer
trainer = SFTTrainer(
model=model,
tokenizer=processor,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=lora_config,
args=SFTConfig(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=5e-5,
num_train_epochs=3,
fp16=False, # Using bfloat16 instead
bf16=True,
logging_steps=10,
evaluation_strategy="steps",
eval_steps=50,
save_steps=100,
output_dir="./medgemma_finetuned",
),
)
7. Train and Evaluate
trainer.train()
trainer.evaluate()
Next Steps
프로토타입을 만든 뒤에는 Cloud Run jobs를 활용해 확장 가능하고 프로덕션에 적합한 환경으로 워크플로를 옮길 수 있습니다(자세한 내용은 추후 포스트를 참고하세요).
Disclaimer: This guide is for informational and educational purposes only and is not a substitute for professional medical advice, diagnosis, or treatment.