Easy FunctionGemma finetuning with Tunix on Google TPUs


Source: Google Developers Blog

FEB. 3, 2026

FunctionGemma is a powerful small language model that lets developers ship fast, cost‑effective agents that translate natural language into actionable API calls, especially on edge devices. In the earlier post, A Guide to Fine‑Tuning FunctionGemma, our colleague shared best practices for fine‑tuning FunctionGemma with the Hugging Face TRL library on GPUs. In this post we explore a different path and run the fine‑tuning on TPUs with Google Tunix. You can find the complete notebook here.

Google stack

Tunix is a lightweight library implemented in JAX and designed to streamline the post‑training of large language models (LLMs). It is part of the extended JAX AI Stack. Tunix supports a wide range of modern LLM post‑training techniques such as supervised fine‑tuning, parameter‑efficient fine‑tuning, preference tuning, reinforcement learning, and model distillation. It works with the latest open models like Gemma, Qwen, and LLaMA, and is built to run efficiently on large‑scale hardware accelerators.

In this tutorial we use LoRA to do supervised fine‑tuning on FunctionGemma and run everything on a free‑tier Colab TPU v5e‑1. We use the same Mobile Action dataset as in the previous fine‑tuning tutorial.


1️⃣ Download the model and dataset

MODEL_ID = "google/functiongemma-270m-it"
DATASET_ID = "google/mobile-actions"

local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    ignore_patterns=["*.pth"]
)
data_file = hf_hub_download(
    repo_id=DATASET_ID,
    filename="dataset.jsonl",
    repo_type="dataset"
)
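Before building batches, the downloaded JSONL file has to be parsed and split into training and validation sets. The helpers below are a minimal sketch, not code from the notebook: they assume each line is a JSON object with a "text" field (which is what the CustomDataset later expects), and the split fraction and seed are illustrative.

```python
import json
import random

def load_jsonl_lines(lines):
    """Parse one JSON record per non-empty line."""
    return [json.loads(line) for line in lines if line.strip()]

def train_val_split(records, val_fraction=0.1, seed=0):
    """Deterministic shuffle, then carve off a validation slice."""
    records = records[:]
    random.Random(seed).shuffle(records)
    n_val = max(1, int(len(records) * val_fraction))
    return records[n_val:], records[:n_val]

# With the real file you would pass the open file handle:
#   with open(data_file) as f:
#       records = load_jsonl_lines(f)
raw = ['{"text": "{\\"messages\\": []}"}'] * 20
records = load_jsonl_lines(raw)
train_data, val_data = train_val_split(records)
```

With 20 records and a 10% validation fraction, this yields 18 training and 2 validation examples.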

2️⃣ Create a (non‑sharded) mesh for the free‑tier TPU

NUM_TPUS = len(jax.devices())
MESH_SHAPE = (1, NUM_TPUS)  # ("fsdp", "tp"); collapses to (1, 1) on a single-chip TPU
mesh = jax.make_mesh(
    MESH_SHAPE,
    ("fsdp", "tp"),
    axis_types=(jax.sharding.AxisType.Auto,) * len(MESH_SHAPE),
)

3️⃣ Load the model and apply LoRA adapters

with mesh:
    base_model = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path, model_config, mesh
    )
    lora_provider = qwix.LoraProvider(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
        rank=LORA_RANK,
        alpha=LORA_ALPHA,
    )
    model_input = base_model.get_model_input()
    model = qwix.apply_lora_to_model(
        base_model,
        lora_provider,
        rngs=nnx.Rngs(0),
        **model_input
    )
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)
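Under the hood, a LoRA adapter leaves the base weight frozen and learns only a low‑rank correction. The NumPy sketch below illustrates the standard LoRA formulation (base path plus an alpha/rank‑scaled low‑rank update, with the up‑projection zero‑initialized); it is not qwix's internal implementation, and the shapes are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 64, 64, 8, 16

W = rng.normal(size=(d_in, d_out))        # frozen base weight
A = rng.normal(size=(d_in, rank)) * 0.01  # trainable down-projection
B = np.zeros((rank, d_out))               # trainable up-projection, zero-init

def lora_forward(x):
    # Base path plus low-rank correction, scaled by alpha / rank.
    return x @ W + (alpha / rank) * (x @ A @ B)

x = rng.normal(size=(2, d_in))
# With B zero-initialized, the adapted model starts out identical to the base.
assert np.allclose(lora_forward(x), x @ W)
```

Only A and B are trained, which is why LoRA fits comfortably in free‑tier TPU memory while the 270M base weights stay untouched.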

4️⃣ Custom dataset for completion‑only loss

class CustomDataset:
    def __init__(self, data, tokenizer, max_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        for item in self.data:
            template_inputs = json.loads(item["text"])

            # Full prompt + completion (no generation token)
            prompt_and_completion = self.tokenizer.apply_chat_template(
                template_inputs["messages"],
                tools=template_inputs["tools"],
                tokenize=False,
                add_generation_prompt=False,
            )
            # Prompt only (adds generation token)
            prompt_only = self.tokenizer.apply_chat_template(
                template_inputs["messages"][:-1],
                tools=template_inputs["tools"],
                tokenize=False,
                add_generation_prompt=True,
            )

            tokenized_full = self.tokenizer(
                prompt_and_completion, add_special_tokens=False
            )
            tokenized_prompt = self.tokenizer(
                prompt_only, add_special_tokens=False
            )

            full_ids = tokenized_full["input_ids"]
            prompt_len = len(tokenized_prompt["input_ids"])

            # Truncate if necessary
            if len(full_ids) > self.max_length:
                full_ids = full_ids[: self.max_length]

            # Pad to max_length
            input_tokens = np.full(
                (self.max_length,),
                self.tokenizer.pad_token_id,
                dtype=np.int32,
            )
            input_tokens[: len(full_ids)] = full_ids

            # Build mask (1 for tokens that should contribute to loss)
            input_mask = np.zeros((self.max_length,), dtype=np.int32)
            if len(full_ids) > prompt_len:
                mask_end = min(len(full_ids), self.max_length)
                input_mask[prompt_len:mask_end] = 1

            yield peft_trainer.TrainingInput(
                input_tokens=jnp.array(input_tokens, dtype=jnp.int32),
                input_mask=jnp.array(input_mask, dtype=jnp.int32),
            )
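To see what the masking logic produces, here is a toy walk‑through with hand‑picked token ids instead of a real tokenizer (the ids, pad id, and lengths are made up for illustration):

```python
import numpy as np

max_length = 12
pad_id = 0
full_ids = [5, 6, 7, 8, 9, 10]  # prompt + completion tokens
prompt_len = 4                  # the first 4 tokens are the prompt

input_tokens = np.full((max_length,), pad_id, dtype=np.int32)
input_tokens[: len(full_ids)] = full_ids

input_mask = np.zeros((max_length,), dtype=np.int32)
input_mask[prompt_len : len(full_ids)] = 1

# Only the two completion tokens (9 and 10) contribute to the loss;
# prompt and padding positions stay masked out.
```

This is exactly the completion‑only behavior: the model is graded on the function call it emits, not on reproducing the prompt.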

5️⃣ Data generators

def data_generator(split_data, batch_size):
    dataset_obj = CustomDataset(split_data, tokenizer, MAX_LENGTH)
    batch_tokens, batch_masks = [], []
    for item in dataset_obj:
        batch_tokens.append(item.input_tokens)
        batch_masks.append(item.input_mask)
        if len(batch_tokens) == batch_size:
            yield peft_trainer.TrainingInput(
                input_tokens=jnp.array(np.stack(batch_tokens)),
                input_mask=jnp.array(np.stack(batch_masks)),
            )
            batch_tokens, batch_masks = [], []

print("Preparing training data...")
train_batches = list(data_generator(train_data, BATCH_SIZE))
val_batches = list(data_generator(val_data_for_loss, BATCH_SIZE))

6️⃣ Kick off the fine‑tuning

print("Starting Training...")
max_steps = len(train_batches) * NUM_EPOCHS
lr_schedule = optax.cosine_decay_schedule(
    init_value=LEARNING_RATE, decay_steps=max_steps
)
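optax.cosine_decay_schedule starts at the peak learning rate and follows a half cosine down to zero over max_steps. The pure‑Python mirror below (assuming optax's default alpha=0) shows the shape of the curve; the rate values are illustrative.

```python
import math

def cosine_decay(init_value, decay_steps):
    """Pure-Python mirror of optax.cosine_decay_schedule with alpha=0:
    the rate starts at init_value and decays along a half cosine to 0."""
    def schedule(step):
        step = min(step, decay_steps)
        return init_value * 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return schedule

lr = cosine_decay(init_value=1e-4, decay_steps=100)
# Starts at the peak rate, halves by the midpoint, reaches zero at the end.
```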

metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir=os.path.join(OUTPUT_DIR, "logs"), flush_every_n_steps=10
)

training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=max_steps,
    checkpoint_root_directory=os.path.join(OUTPUT_DIR, "ckpts"),
    metrics_logging_options=metrics_logging_options,
)

trainer = (
    peft_trainer.PeftTrainer(model, optax.adamw(lr_schedule), training_config)
    .with_gen_model_input_fn(gen_model_input_fn)
)

with mesh:
    trainer.train(train_batches, val_batches)

print("Training Complete.")

The training takes a few minutes, and Tunix achieves a high TPU utilization rate:

TPU utilization

After one epoch, accuracy improves significantly, demonstrating Tunix’s ability to deliver qualitative gains with minimal overhead:

Accuracy before vs. after fine‑tuning

Merge LoRA adapters & export the fine‑tuned model

merged_output_dir = os.path.join(OUTPUT_DIR, "merged")
print(f"Saving merged LoRA model to {merged_output_dir}")

gemma_params.save_lora_merged_model_as_safetensors(
    local_model_path=local_model_path,
    output_dir=merged_output_dir,
    lora_model=model,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
)

print("Model Exported Successfully.")
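Merging folds the low‑rank correction into the base weights, so inference needs no extra adapter matmuls. The NumPy sketch below shows the conventional LoRA merge algebra (not the exact gemma_params internals; shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
d, r, alpha = 32, 4, 8
W = rng.normal(size=(d, d))  # base weight
A = rng.normal(size=(d, r))  # trained down-projection
B = rng.normal(size=(r, d))  # trained up-projection

def adapted(x):
    # Adapter-style forward: base path plus scaled low-rank correction.
    return x @ W + (alpha / r) * (x @ A @ B)

# Merged weight: fold the correction into W, one matmul at inference.
W_merged = W + (alpha / r) * (A @ B)

x = rng.normal(size=(3, d))
# The merged model reproduces the adapted model's outputs exactly.
assert np.allclose(adapted(x), x @ W_merged)
```

The exported safetensors checkpoint therefore behaves like a plain FunctionGemma model and can be served without any LoRA-aware runtime.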

That’s the complete workflow for fine‑tuning FunctionGemma with Tunix. As shown, Tunix is straightforward to use and efficiently leverages Google TPUs. While this example covers supervised fine‑tuning—the simplest approach—Tunix also supports more advanced techniques such as reinforcement learning and other agentic training capabilities that are currently under development.

Conclusion

Tunix bridges the gap between research prototypes and production‑ready systems. Its modularity, JAX‑native speed, and breadth of supported algorithms make it an essential tool for any developer looking to polish their LLMs for specific tasks.

Please check out the Tunix documentation to learn more and follow the Tunix repository for updates.

