Unlocking Open Large Models with Outstanding Price-Performance: A Complete Guide to Inference Optimization on TPU

1. Unlocking Open Large Models with Outstanding Price-Performance: A Complete Guide to Inference Optimization on TPU. Speaker: 杨国强, Google Cloud / AI Infra Architect
2. Table of Contents
4. 01 Introduction to TPUs and Version Evolution
5. The Evolution of TPUs
6. 02 Key Techniques for Deploying Open Models on TPU
7. Inference architecture diagram (frameworks: vLLM / Max* / JetStream / Pathways). A user application sends an HTTP/gRPC request to the inference server; queries enter a query queue and are batched by the scheduler. The vLLM engine serves PyTorch-based LLMs (Llama, DeepSeek); JetStream (MaxText) serves JAX-based LLMs (Gemma); MaxDiffusion serves diffusion models. Responses go back over HTTP/gRPC, metrics are exported, and everything runs on TPU hardware.
8. vLLM overview
● Goal: optimize KV cache memory utilization to achieve high-throughput model serving.
● Key insight: use virtual-memory-style paging for the KV cache to reduce memory waste.
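As a concrete starting point, here is a minimal sketch of serving an open model with vLLM's offline API; the model name, prompt, and sampling settings are illustrative, and on a TPU VM vLLM uses its TPU backend when installed with TPU support.

from vllm import LLM, SamplingParams

# Model name and generation settings are illustrative assumptions.
llm = LLM(model="google/gemma-2b", max_model_len=1024)
params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["Explain what a TPU is in one sentence."], params)
print(outputs[0].outputs[0].text)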
9. vLLM on TPU timeline: executor performance work across prefill and decode.
● Bottlenecks: prefill token redundancy; slow KV cache writes; no parallelism between prefill and decode.
● Optimizations: Multi-Query Attention (TPU-friendly); SparseCore offload + cache layout redesign; Ragged Paged Attention (TPU-friendly).
● Benefits: prefix caching; faster prefill/decode; chunked prefill.
● New features, all available: Multi-LoRA, multimodal model support, Prometheus metrics, V1 refactor, quantized models.
11. APC (Automatic Prefix Caching): reuse previously computed results and eliminate redundant processing. The cache lives in a memory hierarchy (HBM / host memory). For repeated requests that share a common prompt prefix, it reduces TTFT, raises throughput, and ultimately cuts cost while improving efficiency. Typical scenarios: multi-turn conversations, and repeated Q&A over technical documents, books, or videos. Workloads with high cache hit rates see significant throughput gains. https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_advanced_features.ipynb
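A minimal sketch of turning on automatic prefix caching with vLLM's offline API; enable_prefix_caching is a real vLLM engine argument, while the model name and prompts are illustrative.

from vllm import LLM, SamplingParams

# Assumed model; the shared system prompt is what the prefix cache reuses.
llm = LLM(model="google/gemma-2b", enable_prefix_caching=True)
system = "You are a support agent answering questions about the product manual. "
params = SamplingParams(max_tokens=64)
# The second call hits the cached KV blocks of the shared prefix, lowering TTFT.
llm.generate([system + "How do I reset the device?"], params)
llm.generate([system + "How do I update the firmware?"], params)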
12. Chunked Prefill: split a long prompt's prefill into chunks so it can be interleaved with decode steps from other requests, as sketched below.
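A hedged sketch of enabling chunked prefill in vLLM; enable_chunked_prefill and max_num_batched_tokens are real engine arguments, and the model name and token budget are illustrative.

from vllm import LLM

# Assumed model; chunked prefill lets long prompts share engine steps with decodes.
llm = LLM(
    model="google/gemma-2b",
    enable_chunked_prefill=True,
    max_num_batched_tokens=2048,  # upper bound on tokens processed per engine step
)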
13. Continuous Batching: new requests join the running batch as soon as earlier requests finish decoding, instead of waiting for a whole static batch to complete; see the conceptual sketch below.
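This is a conceptual sketch of a continuous-batching loop, not vLLM's actual scheduler; names such as request_queue, engine.step, and max_batch are hypothetical.

# Hypothetical, simplified serving loop illustrating continuous batching.
def serve_loop(engine, request_queue, max_batch):
    running = []
    while True:
        # Admit new requests whenever there is free batch capacity.
        while len(running) < max_batch and not request_queue.empty():
            running.append(request_queue.get())
        # One engine step: prefill chunks for new requests, decode for the rest.
        finished = engine.step(running)
        # Finished sequences leave immediately; their slots are reused next step.
        running = [r for r in running if r not in finished]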
15. 03 TPU Hardware Features: How They Differ from GPUs
16. A v6e pod can serve trillion-parameter-scale models
21. Google's high-speed interconnect turns Cloud TPU Pods into an "AI supercomputer": all-reduce is performed in hardware, making a pod as easy to use as a single computer.
22. Versatile, high-performance model serving: a range of slice shapes to match model size and complexity.
VM shape | Chips per slice | Serves*
1        | 1               | 13 B parameter LLM
2x2      | 4               | 32.5 B parameter LLM
2x4      | 8               | 65 B parameter LLM
4x4      | 16              | 170 B parameter LLM
4x8      | 32              | 280 B parameter LLM
8x8      | 64              | 540 B parameter LLM
8x16     | 128             | 1 T parameter LLM
16x16    | 256             | 2 T parameter LLM
* Google internal data, August 2023. Batch size = 1; multi-head-attention decoder-only language models; prefix length = 2048, decode steps = 256, beam size = 32 for sampling.
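As a rough sanity check on why larger models need larger slices, assume bf16 weights at 2 bytes per parameter and ignore the KV cache and activations; the numbers below are back-of-the-envelope only, not from the table above.

# Hypothetical back-of-the-envelope: bf16 weight memory only (2 bytes/param).
for params_b in [13, 65, 170, 540, 1000, 2000]:
    gib = params_b * 1e9 * 2 / 2**30
    print(f"{params_b} B params -> ~{gib:,.0f} GiB of weights")
# e.g. 65 B params -> ~121 GiB of weights alone, more than one chip's HBM,
# so weights must be sharded across the chips in a slice.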
23. Ironwood, the 7th-generation TPU, is optimized for vLLM from day 1. Versus Trillium: 5x more peak performance per chip, 6x more HBM capacity per chip, 2x more power efficient.
24. At a glance: TPU v5p, Trillium (v6e), and Ironwood (TPU7x), custom-designed AI supercomputers. Ironwood ratios are shown vs. v5p and v6e.
Metric                             | v5p               | Trillium (v6e)    | Ironwood TPU7x (vs. v5p, v6e)
Chips per pod                      | 8,960             | 256               | 9,216
Chip bf16 TFLOPs                   | 459               | 918               | 2,307 (5x, 2.5x)
Chip FP8 TFLOPs                    | 459 (SW emulated) | 918 (SW emulated) | 4,614 (10x, 5x)
HBM per chip (GB)                  | 95                | 32                | 192 (2x, 6x)
HBM bandwidth (GB/s)               | 2,765             | 1,640             | 7,380 (2.7x, 4.5x)
VMEM (MB)                          | 128               | 128               | 128 (1x, 1x)
Host DRAM (GB)                     | 512               | 1,536             | 1,152 (2.25x, 0.75x)
ICI bandwidth per chip (GB/s)      | 1,200 bi-dir      | 800 bi-dir        | 1,200 bi-dir (1x, 1.5x)
DCN bandwidth per chip (Gb/s)      | 50                | 100               | 100 (2x, 1x)
SparseCore (embeddings & offloads) | Yes               | Yes               | Yes
GA                                 | H1 2024           | H2 2024           | Q4 2025
* Google confidential, subject to change.
25. 04 TPU Software Frameworks: How They Differ from GPUs
26. JAX vs. PyTorch software stacks. JAX on TPU (open source, OSS): reference implementations MaxText and MaxDiffusion; framework and ecosystem libraries Flax, Optax, Orbax; JAX core; compiler and kernels OpenXLA and Pallas; driver and runtime libtpu.so; TPU accelerator; shipped together as the JAX Stable Stack Docker images. PyTorch on GPU: reference implementations Megatron-LM and HF Diffusers; PyTorch framework; CUDA kernels and Triton; CUDA driver (closed source); NVIDIA GPU accelerator.
27. Hand-written kernels in Pallas can be used to optimize performance. Compilation flow: user code -> JIT -> XLA HLO -> target-independent optimizations -> target-dependent optimizations -> target-specific code generation via Mosaic (TPU) or Triton (GPU) -> TPU-optimized executable.
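A minimal Pallas sketch, assuming a recent JAX release with TPU support; the kernel just adds two arrays, which is enough to show the pallas_call entry point and the Ref-based block access that Mosaic compiles for TPU.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at blocks resident in VMEM; read, compute, write back.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.ones((8, 128), jnp.float32)
y = jnp.full((8, 128), 2.0, jnp.float32)
print(add(x, y)[0, :4])  # expected: [3. 3. 3. 3.]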
28. Torch/XLA -> Torchax -> next steps
● Currently the best option for running PyTorch programs on TPU/XLA devices.
● PyTorch uses a dispatcher to forward calls to torch ops to the actual implementation of each op; the implementation we are most familiar with is usually a CUDA kernel.
● "XLA" is a special dispatch key handled by the PyTorch dispatcher and hard-coded in PyTorch's source, so the ordering of the dispatch table can be used to map every op.
● The corresponding torch/aten operations in the computation graph are captured, and each aten op is lowered* to XLA ops.
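A minimal sketch of the torch_xla path described above, assuming the torch_xla package is installed on a TPU VM; xm.xla_device() and xm.mark_step() are the standard entry points, and the tensor shapes are arbitrary.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()             # tensors created here carry the XLA dispatch key
x = torch.randn(128, 128, device=device)
y = torch.nn.functional.relu(x @ x)  # recorded as aten ops, lowered lazily to XLA HLO
xm.mark_step()                       # cut the traced graph, compile, and execute on TPU
print(y[0, :4].cpu())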
29. 05 Case Study and Summary
30. Timeline of the Wan2.1 migration:
● 6/9: received the requirement
● ~6/13: Wan2.1 1.3B migrated, in about one week
● 6/14: diffusers chosen as the base repo
● 6/23: Wan2.1 14B works on v6e-8 in 675 s
● 6/27: Wan2.1 14B on v6e-16 in 257 s
● 6/30: Wan2.1 14B on v6e-32 in 150 s
● 7/8: Wan2.1 14B on v6e-16 in 171 s
● 7/10: Wan2.1 14B on v6e-16 in 142 s
31. Diffusion models are NOT single models. Pipeline for the prompt "A cat looking like a tiger": a T5 text encoder produces embeddings; noisy latents plus the timestep feed a Diffusion Transformer; a scheduler iteratively refines the latents; a VAE decoder turns the refined latents into output frames.
32. diffusers

from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
pipeline.to("cuda")
image = pipeline("An image of a squirrel in Picasso style")

from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan as TorchAutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

Wan2.1 files in the diffusers tree:
./src/diffusers/pipelines/wan/pipeline_wan.py
./src/diffusers/pipelines/wan/pipeline_wan_video2video.py
./src/diffusers/pipelines/wan/pipeline_wan_i2v.py
./src/diffusers/models/transformers/transformer_wan.py
./src/diffusers/models/autoencoders/autoencoder_kl_wan.py
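Continuing the imports above, a hedged sketch of how the stock PyTorch WanPipeline in diffusers is typically driven (the GPU path, before the TPU/JAX migration); the model id, resolution, and generation parameters are illustrative.

import torch

# Assumed Hugging Face model id and settings; adjust to the actual checkpoint used.
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = TorchAutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")

frames = pipe(
    prompt="A cat looking like a tiger",
    height=480, width=832, num_frames=81, guidance_scale=5.0,
).frames[0]
export_to_video(frames, "output.mp4", fps=15)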
33. WAN2.1 14B memory profile. Total HBM usage >= 37.11G: reserved 260.00M, program 10.15G, arguments 26.70G. Largest program allocation in HBM:
1. Size: 2.89G
   Operator: op_name="aten::view/reshape" source_file="jaten.py" source_line=55
   Shape: f32[2,75600,40,128]{1,2,3,0:T(8,128)}
   Unpadded size: 2.88G; extra memory due to padding: 1.88M (1.0x expansion)
   XLA label: copy.1253 = copy(bitcast.1528)
   Allocation type: HLO temp
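The reported size checks out from the tensor shape alone; the arithmetic below is a sanity check added here, not part of the original profile output.

# f32[2, 75600, 40, 128]: 4 bytes per element
elements = 2 * 75600 * 40 * 128
print(elements * 4 / 2**30)  # ~2.88 GiB, matching the "Unpadded size" above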
34. Sharding basics
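A small sketch of the JAX named-sharding idiom on a multi-chip slice; the 2x4 mesh, the axis names ("data", "model"), and the array shape are illustrative and assume 8 visible TPU devices (e.g. a v6e-8 slice).

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the 8 devices of the slice as a 2x4 mesh with named axes.
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=("data", "model"))

x = jnp.ones((16, 8192))
# Shard rows across the "data" axis and columns across the "model" axis.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", "model")))
jax.debug.visualize_array_sharding(x_sharded)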
35. Optimization tooling: the profiler
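A minimal sketch of capturing a trace with the JAX profiler for inspection in TensorBoard; the traced function and the output directory are illustrative.

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T)

x = jnp.ones((1024, 1024))
# Write a profile (including TPU device traces) under /tmp/profile for TensorBoard.
with jax.profiler.trace("/tmp/profile"):
    step(x).block_until_ready()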
36. Summary
● TPUs are hardware purpose-built for deep learning.
● Best option for language models: deploy open LLMs with vLLM on TPU.
● Best option for text-to-image generation: deploy text-to-image models with MaxDiffusion.
● TPU's distinctive hardware features deliver better price-performance.
● You are welcome to try TPUs for similar workloads.
38. THANKS. Explore the limits of AI applications.
