TPU vs GPU: Real-World Performance Testing for LLM Training on Google Cloud
Introduction
As Large Language Models (LLMs) continue to grow in scale, the underlying hardware used for training has become the single most critical factor in a project’s success. The industry is currently locked in a fascinating architectural battle: the general‑purpose power of NVIDIA’s GPUs versus the purpose‑built efficiency of Google’s Tensor Processing Units (TPUs).
For engineers and architects building on Google Cloud Platform (GCP), the choice between an A100/H100 GPU cluster and a TPU v4/v5p pod is not merely a matter of cost—it directly impacts software architecture, data pipelines, and convergence speed. This article provides a deep‑dive technical analysis of these two architectures through the lens of real‑world LLM training performance.
Silicon‑Level Differences
The fundamental difference lies in how the chips handle matrix multiplication, the core operation of the Transformer architecture.
| Aspect | NVIDIA GPUs | Google TPUs |
|---|---|---|
| Design philosophy | Many‑core general‑purpose processors with a hierarchy of Streaming Multiprocessors (SMs) and specialized Tensor Cores. | Domain‑Specific Architecture (DSA) built around a systolic‑array design. |
| Memory hierarchy | Complex hierarchy (L1/L2 caches, shared memory) orchestrated via CUDA kernels. | Simplified flow through a grid of processing elements, minimizing register‑file and external‑memory accesses. |
| Core operation | Flexible for graphics, simulations, and neural networks. | Optimized for massive, deterministic matrix multiplications. |
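One way to see this difference from the software side is to lower the same jitted matrix multiplication for whichever backend the process is running on: XLA maps it to Tensor Core instructions on a GPU and to the MXU systolic array on a TPU. The snippet below is a minimal sketch using public JAX inspection APIs; the shapes and dtype are arbitrary placeholders.
import jax
import jax.numpy as jnp
@jax.jit
def matmul(a, b):
    # The core Transformer operation: a dense matrix multiplication.
    return a @ b
a = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
b = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
# Same Python code, different silicon: XLA lowers and optimizes the op
# for whichever backend ("gpu" or "tpu") is available.
print("Backend:", jax.default_backend())
lowered = matmul.lower(a, b)
compiled = lowered.compile()
print(lowered.as_text()[:400])              # backend-agnostic StableHLO
print((compiled.as_text() or "")[:400])     # backend-specific optimized HLO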
Cluster‑Level Communication
Training an LLM like Llama‑3 or GPT‑4 is never done on a single chip; it’s performed on a cluster. The speed of inter‑chip communication often outweighs raw TFLOPS.
- NVIDIA
  - NVLink/NVSwitch: intra-node communication.
  - InfiniBand: inter-node communication.
  - H100 supports NVLink 4, delivering ≈ 900 GB/s of bandwidth per GPU.
- Google TPUs
  - Optical Circuit Switch (OCS) combined with a proprietary Inter-Chip Interconnect (ICI).
  - TPU v4 and v5p leverage OCS to dynamically reconfigure the pod topology, forming a massive 3-D torus that provides low-latency, high-bandwidth communication across thousands of chips without the overhead of traditional networking layers.
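A rough way to feel the interconnect in practice is to time a collective across all locally visible chips. The sketch below is a crude all-reduce microbenchmark (not the methodology used for this article's results); it assumes the process can see multiple accelerators and reports only a per-chip payload rate.
import time
import jax
import jax.numpy as jnp
n_dev = jax.device_count()
# One 64 MiB float32 shard per chip; psum all-reduces across all of them.
all_reduce = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')
shards = jnp.ones((n_dev, 16 * 1024 * 1024), dtype=jnp.float32)
all_reduce(shards).block_until_ready()   # warm-up / compile
start = time.perf_counter()
for _ in range(10):
    out = all_reduce(shards)
out.block_until_ready()
elapsed = time.perf_counter() - start
bytes_per_shard = 16 * 1024 * 1024 * 4
print(f"{n_dev} devices, ~{10 * bytes_per_shard / elapsed / 1e9:.1f} GB/s per-chip payload rate")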
Feature Comparison
| Feature | NVIDIA H100 (SXM5) | Google TPU v5p |
|---|---|---|
| Architecture | Hopper (General Purpose) | Systolic Array (DSA) |
| Memory | 80 GB HBM3 | 95 GB HBM3 |
| Memory Bandwidth | 3.35 TB/s | 4.8 TB/s |
| Interconnect | NVLink 4.0 / InfiniBand | ICI / Optical Circuit Switch |
| Primary Software | CUDA, PyTorch | XLA, JAX, PyTorch |
Real‑World Test Setup
We conducted a training run of a 7 B‑parameter Transformer model (Llama‑2 architecture) on Google Cloud.
| Test Configuration | Details |
|---|---|
| GPU Cluster | 8 × NVIDIA H100 (80 GB) nodes connected via GPUDirect‑TCPX |
| TPU Pod | TPU v5p‑8 (8 cores) and TPU v5p‑32 (32 cores) slices |
| Software Stack | Both platforms benefit from XLA (Accelerated Linear Algebra). While XLA is native to TPUs, OpenXLA enables PyTorch and JAX code to be compiled efficiently for both GPUs and TPUs. TPUs require XLA; GPUs can also run in “eager mode.” |
| Preferred Framework on TPUs | JAX, due to its functional approach that maps naturally onto the systolic array. |
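As the table notes, GPUs can run PyTorch eagerly while TPUs always go through XLA. The sketch below illustrates the torch_xla path in its simplest form; it assumes torch_xla is installed and uses a toy linear layer rather than the Llama-2 model from our tests.
import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()                 # resolves to the TPU core(s)
model = torch.nn.Linear(512, 512).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(32, 512, device=device)
y = torch.randn(32, 512, device=device)
# Ops are recorded lazily into an XLA graph instead of executing eagerly.
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
xm.mark_step()                           # compiles and runs the accumulated graph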
Example JAX Sharding Code (runs on both TPU pods and multi‑GPU setups)
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
# Detect devices (TPU or GPU)
devices = jax.devices()
print(f"Devices found: {devices}")
# Define a 2-D mesh for data and model parallelism.
# Works identically on TPU pods and multi-GPU setups (here: 8 chips as 4 x 2).
device_mesh = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# Describe how arrays should be sharded:
# the 'data' axis shards the batch, the 'model' axis shards the weights.
sharding = NamedSharding(mesh, PartitionSpec('data', 'model'))
# `model`, `cross_entropy`, and `state` (e.g. a Flax TrainState) are assumed
# to be defined elsewhere; this snippet focuses on the sharding setup.
def train_step(state, batch):
    # XLA inserts the communication primitives (e.g. all-reduce)
    # during the gradient computation automatically.
    def loss_fn(params):
        logits = model.apply(params, batch['input'])
        return jnp.mean(cross_entropy(logits, batch['target']))
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)
# JIT-compile the step so XLA can optimize and partition it across the mesh.
parallel_train_step = jax.jit(train_step)
Performance Results
| Metric | NVIDIA H100 | Google TPU v5p |
|---|---|---|
| Throughput (tokens / sec / chip) | ~3,800 | ~3,450 |
| Model FLOPs Utilization (MFU) | ~52 % | ~58 % |
| Observations | Higher raw per‑chip throughput for smaller batches due to higher clock speeds and versatile cache. | Superior MFU and memory bandwidth become evident as batch size scales (≥ 1 M tokens). |
The TPU’s deterministic execution and ICI interconnect minimize idle time, leading to higher overall utilization despite a slightly lower raw throughput per chip.
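For readers who want to reproduce the MFU column, the usual approximation for decoder-only Transformers is FLOPs per token ≈ 6 × parameter count, so MFU = (tokens/sec × 6 × params) / peak FLOPs. The helper below is a minimal sketch; the peak-FLOPs figure you plug in (dense BF16, per chip) is an input you supply for your SKU, and the example value here is a placeholder, not a vendor spec.
def model_flops_utilization(tokens_per_sec: float,
                            n_params: float,
                            peak_flops_per_sec: float) -> float:
    """Approximate MFU for a decoder-only Transformer.

    Uses the common 6 * N FLOPs-per-token training estimate
    (forward + backward, ignoring attention's quadratic term).
    """
    achieved = tokens_per_sec * 6 * n_params
    return achieved / peak_flops_per_sec
# Example with placeholder numbers (per chip):
# a 7e9-parameter model at 3,800 tokens/sec against a hypothetical peak.
print(f"MFU ~ {model_flops_utilization(3_800, 7e9, 3.0e14):.1%}")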
Distributed Training Strategies
| Strategy | GPU Implementation | TPU Implementation |
|---|---|---|
| Data Parallelism | torch.distributed with NCCL | Handled automatically by the GSPMD compiler (XLA) |
| Model Parallelism (Tensor, Pipeline, Sequence) | Manual sharding via PyTorch APIs | GSPMD (the XLA compiler's general single-program, multiple-data partitioner) lets developers write code for a single device; the compiler inserts the necessary sharding logic across the pod. |
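With GSPMD, model-parallel code stays single-device in style: you attach sharding annotations and XLA inserts the collectives. The sketch below shows that annotation style, reusing the mesh and NamedSharding setup from the earlier example; the layer itself is a stand-in, not the Llama-2 block.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
# `mesh` is the ('data', 'model') mesh defined in the sharding example above.
w_sharding = NamedSharding(mesh, P(None, 'model'))   # shard weight columns
x_sharding = NamedSharding(mesh, P('data', None))    # shard the batch
@jax.jit
def layer(x, w):
    # Written as if for one device; the constraint tells GSPMD how to
    # partition the intermediate, and XLA adds the required collectives.
    y = x @ w
    return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P('data', 'model')))
x = jax.device_put(jnp.ones((128, 1024)), x_sharding)
w = jax.device_put(jnp.ones((1024, 4096)), w_sharding)
out = layer(x, w)
print(out.sharding)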
Cost Considerations
Performance must be weighed against cost:
- Google Cloud TPU pricing is generally lower than H100 pricing for equivalent compute time.
- Spot TPUs can be up to 70 % cheaper than on‑demand instances.
- GPUs also offer Spot instances, but price differentials and availability vary by region and demand.
Takeaways
- Raw throughput vs. utilization – H100s lead on small‑batch, per‑chip speed; TPUs excel in sustained utilization at scale.
- Interconnect matters – TPU’s optical circuit switch provides a topology that scales more gracefully across thousands of chips.
- Software ecosystem – Both platforms now support XLA; JAX is the natural fit for TPUs, while PyTorch with NCCL remains the standard on GPUs.
- Cost efficiency – Spot TPUs often deliver the best price‑performance ratio for large‑scale LLM training on GCP.
Choosing the right hardware ultimately depends on your workload’s batch size, desired training speed, and budget constraints. By understanding the architectural nuances outlined above, you can make an informed decision that aligns with both performance goals and cost targets.
Availability & Cost Comparison
The availability of large contiguous blocks of H100 GPUs is often lower than that of TPU slices.
Example Cost Comparison (estimated hourly for an 8‑chip node)
| Configuration | Approx. Spot / Reserved Cost (USD per hour) |
|---|---|
| 8× H100 Node | $12.00 – $15.00 |
| TPU v5p‑8 Slice | $8.00 – $11.00 |
When calculating Tokens per Dollar, the TPU v5p consistently outperformed the H100 by 15–25 % in our training runs, despite the H100 having slightly higher raw throughput. This makes TPUs the preferred choice for long‑running pre‑training stages where budget is a primary constraint.
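Tokens per Dollar here is simply throughput multiplied by chip count and divided by the hourly price. A back-of-the-envelope helper, using mid-range figures from the tables above purely as illustrative inputs:
def tokens_per_dollar(tokens_per_sec_per_chip: float,
                      n_chips: int,
                      hourly_cost_usd: float) -> float:
    """Tokens processed per dollar for an n-chip node or slice."""
    tokens_per_hour = tokens_per_sec_per_chip * n_chips * 3600
    return tokens_per_hour / hourly_cost_usd
# Illustrative mid-range inputs from the tables above, not exact quotes;
# the final ratio depends on where in the quoted price ranges you land.
print(f"H100 node: {tokens_per_dollar(3_800, 8, 13.50):,.0f} tokens per dollar")
print(f"TPU v5p-8: {tokens_per_dollar(3_450, 8, 9.50):,.0f} tokens per dollar")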
When GPUs Still Shine
- Ecosystem & Flexibility – Most open‑source ML research is written first for CUDA. Niche libraries or brand‑new attention mechanisms (e.g., FlashAttention‑3) are usually optimized for NVIDIA first.
- Torch‑XLA allows PyTorch to run on TPUs, but it often requires minor code changes to avoid “context switching” between the CPU and the TPU, which can kill performance.
- Debugging – XLA code is compiled, so you can't simply place a print statement inside your training loop. Use jax.debug.print (see the sketch below) or the Cloud TPU profiler to identify bottlenecks such as HBM stalls or infeed queues.
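For reference, the jax.debug.print path mentioned above looks like this; unlike a plain Python print, it survives jit compilation and fires with the runtime values.
import jax
import jax.numpy as jnp
@jax.jit
def scaled_loss(x):
    loss = jnp.mean(x ** 2)
    # Runs inside the compiled program, printing the runtime value of `loss`.
    jax.debug.print("loss = {l}", l=loss)
    return loss * 0.5
scaled_loss(jnp.arange(4.0))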
Common TPU Bottleneck: Infeed
When using the TPU, a frequent limitation is the Infeed, where the CPU cannot supply data fast enough to keep the TPU busy.
# Using the TPU profiler around a training loop
import jax
with jax.profiler.trace("/tmp/tpu_profile", create_perfetto_link=True):
    for i in range(100):
        state = parallel_train_step(state, next(data_iter))
        # Only sync with the host occasionally so the TPU doesn't sit idle
        # waiting on Python between steps.
        if i % 10 == 0:
            print(f"Step {i} completed")
Decision Tree for LLM Training on Google Cloud
| Scenario | Recommended Accelerator | Why |
|---|---|---|
| Scale is Massive – pre‑training from scratch across hundreds or thousands of chips | TPU v5p | Superior inter‑chip bandwidth (OCS, ICI) and linear scaling |
| JAX/XLA Compatibility – codebase in JAX or comfortable with torch_xla | TPU v5p | Native XLA compilation |
| Cost Sensitivity – need the best “Tokens per Dollar” and can use Spot instances | TPU v5p | Lower cloud pricing, higher utilization |
| Standard Architectures – vanilla Transformer blocks (Attention, MLP, LayerNorm) | TPU v5p | Highly optimized in the XLA compiler |
| Bleeding‑Edge Research – custom CUDA kernels or non‑standard layers lacking XLA support | GPU H100 | CUDA‑first ecosystem |
| Fast Prototyping – eager‑mode PyTorch for quick debugging | GPU H100 | Easier, more interactive development |
| Small‑Scale Fine‑tuning – single‑node (8 GPUs) workloads | GPU H100 | Faster setup, greater flexibility |
| Multi‑Cloud Strategy – portability across AWS, Azure, GCP | GPU H100 (or TPU with abstraction) | Less backend‑specific code changes |
The “TPU vs GPU” debate is no longer about raw speed—it’s about system‑level efficiency for your specific workload.
Summary of Strengths
| Metric | Winner | Reason |
|---|---|---|
| Raw Throughput (Single Node) | GPU H100 | Higher clock speeds and dedicated Transformer Engine |
| Scalability (Multi‑Node) | TPU v5p | Optical Circuit Switch (OCS) and Inter‑Chip Interconnect (ICI) give superior bandwidth |
| Cost per Token | TPU v5p | Lower cloud pricing and higher hardware utilization |
| Developer Velocity | GPU H100 | Massive community support and easier debugging |
| Framework Support | Tie | Both support PyTorch/JAX (GPU natively, TPU via XLA) |
| Future‑Proofing | GPU H100 | CUDA support ensures compatibility with emerging research |
By carefully evaluating your model architecture and budget, you can choose the right accelerator to keep your LLM training project on track and within budget.