Easy FunctionGemma finetuning with Tunix on Google TPUs
Source: Google Developers Blog
FEB. 3, 2026
Fine‑tuning FunctionGemma with Tunix
FunctionGemma is a powerful small language model that lets developers ship fast, cost‑effective agents that translate natural language into actionable API calls, especially on edge devices. In a previous post, A Guide to Fine‑Tuning FunctionGemma, our colleague shared best practices for fine‑tuning FunctionGemma using the Hugging Face TRL library on GPUs.
In this post we explore a different path by using Google Tunix to perform the fine‑tuning on TPUs. You can find the complete notebook here.

Tunix is a lightweight library implemented in JAX and designed to streamline the post‑training of Large Language Models (LLMs). It is part of the extended JAX AI Stack. Tunix supports a wide range of modern LLM post‑training techniques such as supervised fine‑tuning, Parameter‑Efficient Fine‑Tuning, preference tuning, reinforcement learning, and model distillation. It works with the latest open models like Gemma, Qwen, and LLaMA, and is built to scale efficiently across hardware accelerators.
In this tutorial we use LoRA for supervised fine‑tuning of FunctionGemma and run everything on a free‑tier Colab TPU v5e‑1. We use the same Mobile Actions dataset as in the previous fine‑tuning tutorial.
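To make the appeal of LoRA concrete, the sketch below counts trainable parameters for a single weight matrix. The layer shape and rank are illustrative assumptions, not FunctionGemma's actual dimensions, and `lora_param_counts` is a hypothetical helper. The point is only that a rank-r adapter on a d_in × d_out matrix trains r·(d_in + d_out) values instead of d_in·d_out.

```python
def lora_param_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one weight matrix."""
    full = d_in * d_out            # full fine-tuning: every entry of W is trainable
    lora = rank * (d_in + d_out)   # LoRA: A is d_in x rank, B is rank x d_out
    return full, lora


# Hypothetical layer shape and rank, chosen only for illustration.
full, lora = lora_param_counts(d_in=640, d_out=2048, rank=16)
print(f"full fine-tuning: {full:,} params")
print(f"LoRA adapter:     {lora:,} params ({lora / full:.1%} of full)")
```

Because only the small adapter matrices receive gradients and optimizer state, this kind of reduction is what makes fine‑tuning on a single small accelerator practical.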
1️⃣ Download the model and dataset
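from huggingface_hub import hf_hub_download, snapshot_download

MODEL_ID = "google/functiongemma-270m-it"
DATASET_ID = "google/mobile-actions"

# Download the model snapshot, skipping PyTorch *.pth files.
local_model_path = snapshot_download(
    repo_id=MODEL_ID,
    ignore_patterns=["*.pth"],
)
# Download the single JSONL file that holds the dataset.
data_file = hf_hub_download(
    repo_id=DATASET_ID,
    filename="dataset.jsonl",
    repo_type="dataset",
)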
2️⃣ Create a simple mesh (no sharding)
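import jax

NUM_TPUS = len(jax.devices())
# (mesh shape, axis names): 1-way "fsdp" and NUM_TPUS-way "tp"; on a v5e-1 this is (1, 1).
MESH = [(1, NUM_TPUS), ("fsdp", "tp")] if NUM_TPUS > 1 else [(1, 1), ("fsdp", "tp")]
mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]))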
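The mesh-shape logic can be checked without any accelerator attached. The helper below is a hypothetical pure-Python mirror of the `MESH` expression (the name `mesh_layout` is not part of JAX or Tunix). It shows that a single device yields a 1×1 mesh, which is why nothing is actually sharded on a v5e‑1.

```python
def mesh_layout(num_devices: int):
    """Mirror of the MESH expression: 1-way 'fsdp', num_devices-way 'tp'."""
    shape = (1, num_devices) if num_devices > 1 else (1, 1)
    axis_names = ("fsdp", "tp")
    return shape, axis_names


for n in (1, 4, 8):
    shape, names = mesh_layout(n)
    # The product of the mesh shape must equal the number of devices.
    assert shape[0] * shape[1] == n
    print(n, dict(zip(names, shape)))
```

With more than one chip, every device lands on the `tp` axis while the `fsdp` axis stays at size 1.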
3️⃣ Load the model and apply LoRA adapters
with mesh:
    # Load the base weights from the downloaded safetensors snapshot.
    base_model = params_safetensors_lib.create_model_from_safe_tensors(
        local_model_path, model_config, mesh
    )
    # Select the attention and MLP projections that receive LoRA adapters.
    lora_provider = qwix.LoraProvider(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
        rank=LORA_RANK,
        alpha=LORA_ALPHA,
    )
    model_input = base_model.get_model_input()
    model = qwix.apply_lora_to_model(
        base_model,
        lora_provider,
        rngs=nnx.Rngs(0),
        **model_input,
    )
    # Constrain the adapted model's state to its partition specs on the mesh.
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)
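The `module_path` argument is a regular expression over module paths. The sketch below shows which layers such a pattern would select. The module paths are made-up examples, and the use of `re.fullmatch` is an assumption about how the pattern is applied, not a statement about Qwix internals.

```python
import re

# Same pattern as the one passed to qwix.LoraProvider above.
MODULE_PATTERN = r".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj"

# Hypothetical module paths, for illustration only.
candidate_paths = [
    "layers/0/attn/q_einsum",
    "layers/0/attn/kv_einsum",
    "layers/0/attn/attn_vec_einsum",
    "layers/0/mlp/gate_proj",
    "layers/0/mlp/up_proj",
    "layers/0/mlp/down_proj",
    "embedder",
]

# Keep only the paths that the pattern matches in full.
selected = [p for p in candidate_paths if re.fullmatch(MODULE_PATTERN, p)]
for path in selected:
    print("LoRA applied to:", path)
```

Under these assumptions, the query and key/value projections and the three MLP projections receive adapters, while the attention output projection and the embedding table stay frozen.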
4️⃣ Custom dataset for completion‑only loss
import json

import jax.numpy as jnp
import numpy as np


class CustomDataset:
    def __init__(self, data, tokenizer, max_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        for item in self.data:
            template_inputs = json.loads(item["text"])
            # Full conversation: prompt plus the target completion.
            prompt_and_completion = self.tokenizer.apply_chat_template(
                template_inputs["messages"],
                tools=template_inputs["tools"],
                tokenize=False,
                add_generation_prompt=False,
            )
            # Prompt only: drop the last message and add the generation prompt.
            prompt_only = self.tokenizer.apply_chat_template(
                template_inputs["messages"][:-1],
                tools=template_inputs["tools"],
                tokenize=False,
                add_generation_prompt=True,
            )
            tokenized_full = self.tokenizer(prompt_and_completion, add_special_tokens=False)
            tokenized_prompt = self.tokenizer(prompt_only, add_special_tokens=False)
            full_ids = tokenized_full["input_ids"]
            prompt_len = len(tokenized_prompt["input_ids"])
            # Truncate long examples, then right-pad to max_length.
            if len(full_ids) > self.max_length:
                full_ids = full_ids[: self.max_length]
            input_tokens = np.full(
                (self.max_length,), self.tokenizer.pad_token_id, dtype=np.int32
            )
            input_tokens[: len(full_ids)] = full_ids
            # The mask is 1 only on completion tokens, so the loss ignores the prompt and padding.
            input_mask = np.zeros((self.max_length,), dtype=np.int32)
            if len(full_ids) > prompt_len:
                mask_end = min(len(full_ids), self.max_length)
                input_mask[prompt_len:mask_end] = 1
            yield peft_trainer.TrainingInput(
                input_tokens=jnp.array(input_tokens, dtype=jnp.int32),
                input_mask=jnp.array(input_mask, dtype=jnp.int32),
            )
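The masking logic is easier to follow on toy data. The sketch below reproduces the truncate, pad, and mask steps with plain Python lists. The token ids, the pad id of 0, and the helper name `build_example` are all made up for illustration and do not come from the FunctionGemma tokenizer.

```python
def build_example(full_ids, prompt_len, max_length, pad_id=0):
    """Pad or truncate token ids and mark only completion tokens in the mask."""
    full_ids = full_ids[:max_length]                 # truncate long examples
    n = len(full_ids)
    tokens = full_ids + [pad_id] * (max_length - n)  # right-pad to max_length
    mask = [0] * max_length
    for i in range(prompt_len, n):                   # completion positions only
        mask[i] = 1
    return tokens, mask


# Toy ids: the first 4 tokens are the prompt, the last 3 are the completion.
tokens, mask = build_example([11, 12, 13, 14, 21, 22, 23], prompt_len=4, max_length=10)
print(tokens)
print(mask)
```

Prompt tokens and padding both get a mask value of 0, so only the function-call completion contributes to the loss.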
5️⃣ Data generators
def data_generator(split_data, batch_size):
    dataset_obj = CustomDataset(split_data, tokenizer, MAX_LENGTH)
    batch_tokens, batch_masks = [], []
    for item in dataset_obj:
        batch_tokens.append(item.input_tokens)
        batch_masks.append(item.input_mask)
        # Emit a batch once it is full; a trailing partial batch is never yielded.
        if len(batch_tokens) == batch_size:
            yield peft_trainer.TrainingInput(
                input_tokens=jnp.array(np.stack(batch_tokens)),
                input_mask=jnp.array(np.stack(batch_masks)),
            )
            batch_tokens, batch_masks = [], []


print("Preparing training data...")
train_batches = list(data_generator(train_data, BATCH_SIZE))
val_batches = list(data_generator(val_data_for_loss, BATCH_SIZE))
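One consequence of this batching pattern is that examples left over after the last full batch are silently dropped. The sketch below isolates that behavior with a hypothetical `full_batches` helper over plain integers, with no Tunix types involved.

```python
def full_batches(items, batch_size):
    """Yield full batches only; a trailing partial batch is dropped."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []


batches = list(full_batches(range(7), batch_size=3))
dropped = 7 - sum(len(b) for b in batches)
print(batches)
print("examples dropped:", dropped)
```

Every batch therefore has the same shape, which avoids recompilation under JAX, at the cost of skipping up to `batch_size - 1` examples per split.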
6️⃣ Kick off the fine‑tuning
import os

import optax

print("Starting Training...")
max_steps = len(train_batches) * NUM_EPOCHS
# Decay the learning rate along a cosine curve over the whole run.
lr_schedule = optax.cosine_decay_schedule(
    init_value=LEARNING_RATE, decay_steps=max_steps
)
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir=os.path.join(OUTPUT_DIR, "logs"), flush_every_n_steps=10
)
training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=max_steps,
    checkpoint_root_directory=os.path.join(OUTPUT_DIR, "ckpts"),
    metrics_logging_options=metrics_logging_options,
)
trainer = (
    peft_trainer.PeftTrainer(model, optax.adamw(lr_schedule), training_config)
    .with_gen_model_input_fn(gen_model_input_fn)
)
# Run training inside the mesh context so arrays are placed on the TPU mesh.
with mesh:
    trainer.train(train_batches, val_batches)
print("Training Complete.")
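To see what the cosine schedule does to the learning rate, the sketch below re-derives the curve in pure Python. It assumes decay to zero, which matches the default `alpha=0.0` of `optax.cosine_decay_schedule`. The learning rate and step count are hypothetical, since the notebook's `LEARNING_RATE` and `max_steps` are not shown here.

```python
import math


def cosine_decay(step: int, init_value: float, decay_steps: int) -> float:
    """Cosine decay from init_value to 0 over decay_steps, constant afterwards."""
    progress = min(step, decay_steps) / decay_steps
    return init_value * 0.5 * (1.0 + math.cos(math.pi * progress))


# Hypothetical values, chosen only for illustration.
LEARNING_RATE, MAX_STEPS = 1e-4, 100
for step in (0, 25, 50, 75, 100):
    print(step, f"{cosine_decay(step, LEARNING_RATE, MAX_STEPS):.2e}")
```

The rate starts at its initial value, passes through half of it at the midpoint, and reaches zero at the final step, so the largest updates happen early in the run.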
Training takes a few minutes, and Tunix achieves high TPU utilization during the run.

After one epoch of training, accuracy improves significantly, showing that Tunix can deliver meaningful quality gains with minimal overhead.

When we are satisfied with the performance, we merge the LoRA adapters and export the finetuned model back to safetensors for downstream processing (e.g., on‑device deployment with LiteRT).
merged_output_dir = os.path.join(OUTPUT_DIR, "merged")
print(f"Saving merged LoRA model to {merged_output_dir}")
gemma_params.save_lora_merged_model_as_safetensors(
    local_model_path=local_model_path,
    output_dir=merged_output_dir,
    lora_model=model,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
)
print("Model Exported Successfully.")
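Conceptually, merging folds each adapter into its base weight so that inference needs no extra matrices. The sketch below illustrates this on a toy 2×2 weight using the standard LoRA convention W' = W + (alpha / rank)·A·B. It is an assumption that the export above uses exactly this scaling, and the helper names and values are made up.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(w, a, b, rank, alpha):
    """Return W + (alpha / rank) * A @ B, the merged weight."""
    scale = alpha / rank
    delta = matmul(a, b)
    return [[wij + scale * dij for wij, dij in zip(wr, dr)] for wr, dr in zip(w, delta)]


# Toy base weight with a rank-1 adapter (all values are made up).
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0], [2.0]]        # d_in x rank
B = [[0.5, -0.5]]         # rank x d_out
merged = merge_lora(W, A, B, rank=1, alpha=2)

# The merged weight must reproduce the base path plus the scaled adapter path.
x = [[3.0, 4.0]]
base_out = matmul(x, W)
adapter_out = matmul(matmul(x, A), B)
with_adapter = [[p + 2.0 * q for p, q in zip(r1, r2)] for r1, r2 in zip(base_out, adapter_out)]
assert matmul(x, merged) == with_adapter
print(merged)
```

After merging, the exported checkpoint has the same tensor shapes as the original model, which is what lets downstream tools consume it without knowing about LoRA.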
That’s the complete workflow for fine‑tuning FunctionGemma with Tunix. As shown, Tunix is straightforward to use and makes efficient use of Google TPUs. This example covers supervised fine‑tuning, the simplest approach, but Tunix also supports more advanced techniques such as reinforcement learning, and we are actively adding further agentic training capabilities.
Conclusion
Tunix bridges the gap between research prototypes and production‑ready systems. Its modularity, JAX‑native speed, and breadth of supported algorithms make it an essential tool for any developer looking to polish their LLMs for specific tasks.
Check out the Tunix documentation to learn more, and follow the Tunix repository for updates.