Unlocking the Potential of Open Large Models at Outstanding Price-Performance: A Deep Dive into Inference Optimization on TPU
1. Unlocking the Potential of Open Large Models at Outstanding Price-Performance
Inference Optimization on TPU
杨国强
Google Cloud / AI Infra Architect
2. Contents
3.
4. 01 TPU Overview and Version Evolution
5. The Evolution of TPU
6. 02 Key Techniques for Deploying Open Models on TPU
7. Inference architecture diagram (frameworks: vLLM / Max* / JetStream / Pathways)
[Diagram] The user application sends HTTP/gRPC requests to the Inference Server, where queries pass through a query queue, scheduler, and batching before execution. PyTorch-based LLMs (Llama, DeepSeek) run on the vLLM engine; JAX-based LLMs (Gemma) run on JetStream (MaxText); diffusion models run on MaxDiffusion. Responses are returned over HTTP/gRPC, metrics are exported, and everything executes on TPU hardware.
8. vLLM Overview
● Goal: optimize KV-cache memory utilization to achieve high-throughput model serving
● Key insight: apply virtual-memory-style paging to the KV cache to reduce memory waste (a minimal PagedAttention-style sketch follows below)
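A purely illustrative sketch of the paging idea (not vLLM's actual implementation): the KV cache is split into fixed-size blocks, and each sequence keeps a block table of the blocks it owns, so memory is allocated on demand instead of being reserved for the maximum length. BLOCK_SIZE and the block-manager shape are assumptions for illustration.

# Minimal PagedAttention-style KV-cache bookkeeping sketch (illustrative only).
BLOCK_SIZE = 16  # tokens per KV block (assumed value)

class BlockManager:
    def __init__(self, num_blocks: int):
        self.free_blocks = list(range(num_blocks))      # pool of physical KV blocks
        self.block_tables = {}                          # seq_id -> list of block ids

    def ensure_capacity(self, seq_id: int, num_tokens: int) -> None:
        """Allocate just enough blocks for the sequence to hold num_tokens KV entries."""
        table = self.block_tables.setdefault(seq_id, [])
        needed = -(-num_tokens // BLOCK_SIZE)           # ceiling division
        while len(table) < needed:
            table.append(self.free_blocks.pop())

    def free(self, seq_id: int) -> None:
        """Return a finished sequence's blocks to the free pool for reuse."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))

mgr = BlockManager(num_blocks=1024)
mgr.ensure_capacity(seq_id=0, num_tokens=50)   # 50 tokens -> 4 blocks of 16
print(mgr.block_tables[0])
mgr.free(0)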
9. vLLM on TPU Timeline

             | Prefill                              | Decode                                     | Executor performance
Bottleneck   | Prefill token redundancy             | Slow KV-cache writes                       | No parallelism between prefill and decode
Optimization | Multi-query attention (TPU-friendly) | SparseCore offload + cache layout redesign | Ragged paged attention (TPU-friendly)
Benefit      | Prefix caching                       | Faster prefill/decode                      | Chunked prefill

New features (all available): Multi-LoRA, multimodal model support, Prometheus metrics, V1 refactor, quantized models
10.
11. APC (Automatic Prefix Caching)
Reuses previously computed results, eliminating redundant processing.
The cache lives in tiered memory (HBM / host memory).
Reduces TTFT for repeated requests that share a common prompt prefix.
Improves throughput and ultimately cuts cost while raising efficiency.
Typical use cases: multi-turn conversations; repeated Q&A over technical documents, books, or videos.
Significantly boosts system throughput in workloads with high cache-hit rates (a minimal usage sketch follows below).
https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_advanced_features.ipynb
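A minimal sketch of turning on automatic prefix caching in vLLM's offline API; the model id and prompts are placeholders, and flag behavior may vary across vLLM/TPU releases.

from vllm import LLM, SamplingParams

# Enable automatic prefix caching so repeated prompt prefixes reuse cached KV blocks.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",   # placeholder model id
          enable_prefix_caching=True)

shared_prefix = "You are a helpful assistant. Here is the product manual: ..."
params = SamplingParams(max_tokens=128)

# Both requests share the same prefix, so the second prefill hits the cache and TTFT drops.
outputs = llm.generate([shared_prefix + "\nQ: How do I reset the device?",
                        shared_prefix + "\nQ: What is the warranty period?"], params)
for o in outputs:
    print(o.outputs[0].text)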
12. Chunked Prefill
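Chunked prefill splits a long prompt's prefill into token-budgeted chunks so decode steps of other requests can be interleaved instead of stalling behind one long prefill. A minimal sketch with vLLM's offline API; the model id and token budget are assumed values, and defaults differ across vLLM versions.

from vllm import LLM, SamplingParams

# Enable chunked prefill and cap the tokens batched per scheduler step,
# so a long prompt is prefilled in chunks while decodes keep making progress.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",   # placeholder model id
          enable_chunked_prefill=True,
          max_num_batched_tokens=2048)                # per-step token budget (assumed value)

long_prompt = "..." * 4000   # stands in for a very long document prompt
print(llm.generate([long_prompt + "\nSummarize the text above."],
                   SamplingParams(max_tokens=64))[0].outputs[0].text)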
13. Continuous Batching
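Continuous (in-flight) batching admits new requests and retires finished ones at every decode step, instead of waiting for a whole static batch to complete. A simplified scheduler loop, purely illustrative and not vLLM's actual implementation; the Request class and model_step stub are stand-ins.

import collections
from dataclasses import dataclass

@dataclass
class Request:
    prompt: str
    tokens_left: int              # decode steps remaining (stand-in for EOS detection)

    def is_done(self) -> bool:
        return self.tokens_left <= 0

def model_step(batch):
    # Stand-in for one forward pass that emits one token per running sequence.
    for r in batch:
        r.tokens_left -= 1

def continuous_batching_loop(incoming, max_batch=8):
    waiting = collections.deque(incoming)
    running, finished = [], []
    while waiting or running:
        # Admit new requests whenever there is free capacity in the running batch.
        while waiting and len(running) < max_batch:
            running.append(waiting.popleft())
        model_step(running)                           # one decode iteration
        # Retire finished sequences immediately; their slots free up next step.
        finished += [r for r in running if r.is_done()]
        running = [r for r in running if not r.is_done()]
    return finished

done = continuous_batching_loop([Request(f"q{i}", tokens_left=i % 5 + 1) for i in range(20)])
print(len(done))  # 20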
14.
15. 03 TPU Hardware Characteristics: How They Differ from GPUs
16. A v6e pod supports models at the trillion-parameter scale
17.
18.
19.
20.
21. Google's high-speed interconnect makes Cloud TPU Pods an "AI supercomputer"
All-reduce performed in hardware
As easy to use as a single computer
22. Versatile and high-performance model serving
A range of slice shapes to match model size and complexity (diagram shows a 16-chip slice)

VM shape | Chips/slice | Serves*
1        | 1           | 13 B-param LLM
2x2      | 4           | 32.5 B-param LLM
2x4      | 8           | 65 B-param LLM
4x4      | 16          | 170 B-param LLM
4x8      | 32          | 280 B-param LLM
8x8      | 64          | 540 B-param LLM
8x16     | 128         | 1 T-param LLM
16x16    | 256         | 2 T-param LLM

* Google internal data, August 2023. Batch size = 1. Multi-head-attention decoder-only language models: prefix length = 2048, decode steps = 256, beam size = 32 for sampling.
23. Optimized on vLLM from Day 1
Ironwood, the 7th-generation TPU*:
5x more peak performance per chip
6x more HBM capacity per chip
2x more power efficient
* vs. Trillium
24. At a glance: TPU v5p, Trillium, Ironwood
Custom-designed AI supercomputers

                                   | v5p               | Trillium (v6e)    | Ironwood TPU7x (vs. v5p, v6e)
Number of chips per pod            | 8,960             | 256               | 9,216
Chip BF16 TFLOPs                   | 459               | 918               | 2,307 (5x, 2.5x)
Chip FP8 TFLOPs                    | 459 (SW emulated) | 918 (SW emulated) | 4,614 (10x, 5x)
HBM per chip (GB)                  | 95                | 32                | 192 (2x, 6x)
HBM BW (GB/s)                      | 2,765             | 1,640             | 7,380 (2.7x, 4.5x)
VMEM (MB)                          | 128               | 128               | 128 (1x, 1x)
Host DRAM (GB)                     | 512               | 1,536             | 1,152 (2.25x, 0.75x)
ICI BW per chip (GB/s)             | 1,200 bi-dir      | 800 bi-dir        | 1,200 bi-dir (1x, 1.5x)
DCN BW per chip (Gb/s)             | 50                | 100               | 100 (2x, 1x)
SparseCore (embeddings & offloads) | Yes               | Yes               | Yes
GA                                 | H1 2024           | H2 2024           | Q4 2025

*Google Confidential, Subject to Change
25. 04 TPU Software Frameworks: How They Differ from GPUs
26. JAX vs PyTorch open-source stacks (diagram; legend distinguishes open-source from closed-source components)
JAX on TPU (shipped as the JAX Stable Stack Docker images):
Reference implementations: MaxText, MaxDiffusion
Framework & ecosystem: Flax, Optax, Orbax
Compiler & kernels: JAX Core, OpenXLA, Pallas
Driver & runtime: libtpu.so
Accelerator: TPU
PyTorch on GPU:
Reference implementations: Megatron-LM, HF Diffusers
Framework & ecosystem: PyTorch
Compiler & kernels: CUDA kernels, Triton
Driver & runtime: CUDA driver
Accelerator: NVIDIA GPU
27. Hand-written Pallas kernels can be used to optimize performance
Compilation flow (diagram): user code → JIT → XLA HLO → target-independent optimizations → XLA HLO → target-dependent optimizations → target-specific code generation via Mosaic (TPU) / Triton (GPU) → optimized TPU executable. A minimal Pallas kernel sketch follows below.
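A minimal Pallas kernel sketch, assuming a recent JAX release; it only shows the pallas_call structure for an element-wise add, not a performance-tuned TPU kernel.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at blocks of the operands: read, compute, and write back.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    # pallas_call lowers the kernel through Mosaic when running on TPU.
    return pl.pallas_call(add_kernel,
                          out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
print(add(x, x)[0, :4])  # [0. 2. 4. 6.]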
28. Torch/XLA -> Torchax -> next 1 -> next 2
● Currently the best option for running PyTorch programs on TPU/XLA devices.
● PyTorch uses a dispatcher to forward each torch op call to an actual implementation; the ones we know best are usually CUDA kernels.
● "XLA" is a special dispatch key handled by the PyTorch dispatcher and hard-coded in PyTorch's source, so the ordering of the dispatch table can be leveraged to map every op.
● The corresponding torch/aten operations are captured into a computation graph.
● Each aten op is then lowered* to XLA ops. (A minimal usage sketch follows below.)
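A minimal sketch of running a PyTorch computation on an XLA (TPU) device via torch_xla; API details vary across torch_xla releases, so treat this as illustrative rather than the deployment path used in the talk.

import torch
import torch_xla.core.xla_model as xm

# Tensors placed on the XLA device are traced into a graph of aten ops,
# which torch_xla lowers to XLA HLO and compiles for the TPU.
device = xm.xla_device()
x = torch.randn(128, 128, device=device)
w = torch.randn(128, 128, device=device)
y = (x @ w).relu()

# Execution is lazy: mark_step() cuts the graph and triggers compilation/execution.
xm.mark_step()
print(y.sum().item())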
29. 05 Case Study and Summary
30. Timeline
6/9: received the requirement
~6/13: Wan2.1 1.3B migrated within about one week
6/14: diffusers chosen as the base repo
6/23: Wan2.1 14B works on v6e-8 in 675 s
6/27: Wan2.1 14B on v6e-16 in 257 s
6/30: Wan2.1 14B on v6e-32 in 150 s
7/8: Wan2.1 14B on v6e-16 in 171 s
7/10: Wan2.1 14B on v6e-16 in 142 s
31. Diffusion models are NOT single models
Pipeline (diagram): a text prompt ("A cat looking like a tiger") is encoded by T5 text encoders into embeddings; noisy latents, the embeddings, and the timestep are fed to a Diffusion Transformer; a scheduler iteratively refines the latents; and a VAE decoder converts the refined latents into the final output.
32. diffusers

# Standard diffusers usage: load a pipeline, swap the scheduler, move to the accelerator, generate.
from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
pipeline.to("cuda")
image = pipeline("An image of a squirrel in Picasso style").images[0]

# Wan 2.1 pieces in diffusers: video export helper, the Wan VAE, the pipeline, and the scheduler.
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan as TorchAutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

Relevant source files:
./src/diffusers/pipelines/wan/pipeline_wan.py
./src/diffusers/pipelines/wan/pipeline_wan_video2video.py
./src/diffusers/pipelines/wan/pipeline_wan_i2v.py
./src/diffusers/models/transformers/transformer_wan.py
./src/diffusers/models/autoencoders/autoencoder_kl_wan.py
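A minimal text-to-video sketch built from the imports above; the checkpoint id, frame count, and fps are assumptions taken from typical diffusers Wan examples rather than from the slides.

import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"   # assumed checkpoint id
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")  # on TPU, the model would instead go through the JAX/torchax path described earlier

frames = pipe(prompt="A cat looking like a tiger",
              num_frames=33, num_inference_steps=30).frames[0]
export_to_video(frames, "wan_t2v.mp4", fps=16)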
33. WAN2.1 14B
Total HBM usage >= 37.11G: reserved 260.00M, program 10.15G, arguments 26.70G

Largest program allocations in HBM:
1. Size: 2.89G
   Operator: op_name="aten::view/reshape" source_file="jaten.py" source_line=55
   Shape: f32[2,75600,40,128]{1,2,3,0:T(8,128)}
   Unpadded size: 2.88G
   Extra memory due to padding: 1.88M (1.0x expansion)
   XLA label: copy.1253 = copy(bitcast.1528)
   Allocation type: HLO temp
==========================
34. Sharding Basics
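A minimal JAX sharding sketch (not from the slides): it builds a named device mesh and shards a matmul across chips with NamedSharding. The mesh axis names, mesh shape, and array sizes are placeholders and assume at least 4 devices.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the chips of a slice into a 2D mesh with named axes (assumes >= 4 devices).
mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), axis_names=("data", "model"))

# Shard activations along "data" and weights along "model".
x = jax.device_put(jnp.ones((128, 512)), NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((512, 256)), NamedSharding(mesh, P(None, "model")))

@jax.jit
def forward(x, w):
    # The compiler inserts the needed collectives; the output stays sharded.
    return x @ w

y = forward(x, w)
print(y.sharding)  # NamedSharding over the ("data", "model") mesh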
35. Optimization Tooling: the Profiler
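A minimal profiling sketch using the JAX profiler, the kind of capture that surfaces HBM and time hotspots like the memory report on the previous slide; the trace directory is a placeholder, and the captured trace can be inspected with TensorBoard's profile plugin or Perfetto.

import jax
import jax.numpy as jnp

x = jnp.ones((4096, 4096))

# Capture a trace of the compiled computation for later inspection.
with jax.profiler.trace("/tmp/jax-trace"):   # placeholder output directory
    y = (x @ x).block_until_ready()

# Annotate a region so it shows up as a named span in the trace viewer.
with jax.profiler.TraceAnnotation("second_matmul"):
    z = (y @ x).block_until_ready()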
36. Summary
● TPU is hardware purpose-built for deep learning.
● Best option for language models: deploy open LLMs with vLLM on TPU.
● Best option for text-to-image: deploy text-to-image generation models with MaxDiffusion.
● TPU's distinctive hardware characteristics deliver better price-performance.
● You are welcome to adopt TPU for similar workloads.
37.
38. THANKS
Explore the limits of AI applications