Google TPU에서 Tunix로 Easy FunctionGemma 파인튜닝

발행: (2026년 2월 28일 오전 10:53 GMT+9)
6 분 소요

Source: Google Developers Blog

2026년 2월 3일

Source: https://developers.googleblog.com/a-guide-to-fine-tuning-functiongemma/

Fine‑tuning FunctionGemma with Tunix

FunctionGemma는 개발자들이 빠르고 비용 효율적인 에이전트를 제공할 수 있게 해 주는 강력한 소형 언어 모델로, 특히 엣지 디바이스에서 자연어를 실행 가능한 API 호출로 변환할 수 있습니다. 이전 A Guide to Fine‑Tuning FunctionGemma 블로그에서 동료가 GPU에서 Hugging Face TRL 라이브러리를 사용해 FunctionGemma를 파인튜닝하는 모범 사례를 공유했습니다.

이번 포스트에서는 Google Tunix 을 사용해 TPU에서 파인튜닝하는 다른 경로를 탐색합니다. 전체 노트북은 여기 에서 확인할 수 있습니다.

Google stack

Tunix는 JAX로 구현된 경량 라이브러리로, 대형 언어 모델(LLM)의 사후 학습(post‑training)을 간소화하도록 설계되었습니다. 이는 extended JAX AI Stack 의 일부입니다. Tunix는 감독 파인튜닝, Parameter‑Efficient Fine‑Tuning, 프리퍼런스 튜닝, 강화 학습, 모델 증류와 같은 최신 LLM 사후 학습 기법을 폭넓게 지원합니다. Gemma, Qwen, LLaMA와 같은 최신 오픈 모델과 호환되며, 하드웨어 가속기 전반에 걸쳐 효율적으로 확장되도록 구축되었습니다.

이 튜토리얼에서는 LoRA 를 사용해 FunctionGemma를 감독 파인튜닝하고, 모든 과정을 무료 티어 Colab TPU v5e‑1에서 실행합니다. 이전 파인튜닝 튜토리얼 에서 사용한 Mobile Action 데이터셋을 동일하게 사용합니다.

1️⃣ 모델 및 데이터셋 다운로드

MODEL_ID = "google/functiongemma-270m-it"
DATASET_ID = "google/mobile-actions"

local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    ignore_patterns=["*.pth"]
)
data_file = hf_hub_download(
    repo_id=DATASET_ID,
    filename="dataset.jsonl",
    repo_type="dataset"
)

2️⃣ 간단한 메시 생성 (sharding 없음)

NUM_TPUS = len(jax.devices())
MESH = [(1, NUM_TPUS), ("fsdp", "tp")] if NUM_TPUS > 1 else [(1, 1), ("fsdp", "tp")]
mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]))

3️⃣ 모델 로드 및 LoRA 어댑터 적용

with mesh:
    base_model = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path, model_config, mesh
    )
    lora_provider = qwix.LoraProvider(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
        rank=LORA_RANK,
        alpha=LORA_ALPHA,
    )
    model_input = base_model.get_model_input()
    model = qwix.apply_lora_to_model(
        base_model,
        lora_provider,
        rngs=nnx.Rngs(0),
        **model_input,
    )
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)

4️⃣ Completion‑only 손실을 위한 커스텀 데이터셋

class CustomDataset:
    def __init__(self, data, tokenizer, max_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        for item in self.data:
            template_inputs = json.loads(item["text"])
            prompt_and_completion = self.tokenizer.apply_chat_template(
                template_inputs["messages"],
                tools=template_inputs["tools"],
                tokenize=False,
                add_generation_prompt=False,
            )
            prompt_only = self.tokenizer.apply_chat_template(

5️⃣ 데이터 생성기

def data_generator(split_data, batch_size):
    dataset_obj = CustomDataset(split_data, tokenizer, MAX_LENGTH)
    bat

(이 스니펫은 원본 소스와 동일하게 여기서 끝납니다.)

batch_tokens, batch_masks = [], []
for item in dataset_obj:
    batch_tokens.append(item.input_tokens)
    batch_masks.append(item.input_mask)
    if len(batch_tokens) == batch_size:
        yield peft_trainer.TrainingInput(
            input_tokens=jnp.array(np.stack(batch_tokens)),
            input_mask=jnp.array(np.stack(batch_masks)),
        )
        batch_tokens, batch_masks = [], []

print("Preparing training data...")
train_batches = list(data_generator(train_data, BATCH_SIZE))
val_batches = list(data_generator(val_data_for_loss, BATCH_SIZE))

파인튜닝 시작

print("Starting Training...")
max_steps = len(train_batches) * NUM_EPOCHS
lr_schedule = optax.cosine_decay_schedule(
    init_value=LEARNING_RATE, decay_steps=max_steps
)

metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir=os.path.join(OUTPUT_DIR, "logs"), flush_every_n_steps=10
)

training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=max_steps,
    checkpoint_root_directory=os.path.join(OUTPUT_DIR, "ckpts"),
    metrics_logging_options=metrics_logging_options,
)

trainer = (
    peft_trainer.PeftTrainer(model, optax.adamw(lr_schedule), training_config)
    .with_gen_model_input_fn(gen_model_input_fn)
)

with mesh:
    trainer.train(train_batches, val_batches)

print("Training Complete.")

훈련은 몇 분 정도 걸리며, Tunix는 실행 중에 높은 TPU 활용률을 달성할 수 있습니다.

TPU 활용률

한 에포크 훈련 후, 정확도가 크게 상승하는 것을 확인할 수 있으며, 이는 최소한의 오버헤드로도 Tunix가 정성적 개선을 이끌어낼 수 있음을 보여줍니다.

파인튜닝 전 정확도 vs. 파인튜닝 후 정확도

성능에 만족하면 LoRA 어댑터를 병합하고 파인튜닝된 모델을 safetensors 형식으로 다시 내보내어 다운스트림 처리(예: LiteRT 를 이용한 온‑디바이스 배포) 에 활용합니다.

merged_output_dir = os.path.join(OUTPUT_DIR, "merged")
print(f"Saving merged LoRA model to {merged_output_dir}")

gemma_params.save_lora_merged_model_as_safetensors(
    local_model_path=local_model_path,
    output_dir=merged_output_dir,
    lora_model=model,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
)

print("Model Exported Successfully.")

이것이 Functi 파인튜닝을 위한 전체 워크플로우입니다.

onGemma**와 Tunix. 보시다시피, Tunix는 사용하기 간단하며 Google TPU를 효율적으로 활용할 수 있습니다. 이 예제는 가장 간단한 방법인 감독식 파인튜닝을 다루지만, Tunix는 강화 학습과 같은 보다 고급 기술도 지원하며, 우리는 현재 에이전트 학습 기능을 추가로 개발하고 있습니다.

결론

Tunix는 연구 프로토타입과 프로덕션‑준비 시스템 사이의 격차를 메워줍니다. 모듈성, JAX‑네이티브 속도, 그리고 지원되는 알고리즘의 폭넓음은 특정 작업을 위해 LLM을 다듬고자 하는 모든 개발자에게 필수적인 도구가 됩니다.

더 자세히 알아보려면 Tunix **documentation**를 확인하고, 업데이트를 위해 **Tunix repository**를 팔로우하세요.

0 조회
Back to Blog

관련 글

더 보기 »

Google I/O 2026을 준비하세요

Google I/O가 5월 19~20일에 돌아옵니다! Google I/O가 다시 찾아왔습니다. 최신 AI 혁신과 전사 제품 업데이트를 온라인에서 공유합니다—Gemini부터 시작해서요.