CogView: Pretraining for Text-to-Image Generation
1. CogView :
Mastering Text-to-Image Generation
via Transformers
NeurIPS 2021, Virtual Conference
Ming Ding
In collaboration with: Zhuoyi Yang, Wenyi Hong, Wendi Zheng, Chang Zhou,
Da Yin, Junyang Lin, Xu Zou, Zhou Shao, Hongxia Yang, Jie Tang
Tsinghua University, Alibaba Group and BAAI
2. Domain-general text-to-image generation is challenging…
• Domain-specific unconditional generation: StyleGAN2 on FFHQ.
• Domain-specific text-to-image generation: DF-GAN on CUB, e.g. "This bird has a white belly and breast, with a blue crown and nape."
• Domain-general text-to-image generation: DF-GAN on COCO, e.g. "Close-up of a man eating a piece of pizza while holding a plate."
Karras, Tero, et al. "Analyzing and improving the image quality of StyleGAN." CVPR 2020.
Tao, Ming, et al. "DF-GAN: Deep fusion generative adversarial networks for text-to-image synthesis." arXiv 2020.
3. Do heuristic pipelines help?
• Example pipeline for "An oil painting of 'A girl is riding a unicorn at the riverside.'": recall single objects (a girl riding, a unicorn, the riverside), compose them, then apply the oil-painting style.
• Obj-GAN [Li et al., CVPR 2019]
• RetrieveGAN [Tseng et al., ECCV 2020]
• But what about "a stained glass window with an image of a fox"?
4. What we really need is a powerful generative model!
• The most powerful generative model in NLP is the autoregressive
transformer (GPT).
• GPT can only deal with tokens in a discrete dictionary.
• VQVAE can compress an image into discrete tokens.
• Directly concatenate text and image tokens for language modeling.
Radford, Alec, et al. "Language Models are Unsupervised Multitask Learners." OpenAI blog. 2019.
Oord, Aaron van den, et al. "Neural discrete representation learning." NeurIPS 2017.
5. What we really need is a powerful generative model!
Text Tokenizer (SentencePiece): converts the input text into text tokens.
Image Tokenizer (discrete autoencoder): discretizes the input image into image tokens, and can recover the image from them.
The image tokens are flattened and concatenated with the text tokens into a single sequence:
[ROI1] text tokens [BASE] [BOI1] image tokens [EOI1]
Transformer (GPT)
• 48 layers, hidden size 2560, 40 heads (about 4B parameters in total).
• 30M text-image pairs; 512 V100 GPUs × 427 hours.
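Conceptually, pretraining then reduces to ordinary left-to-right language modeling over this concatenated sequence. Below is a minimal sketch of how such a sequence could be assembled; the separator IDs, dummy tokenizer outputs, and the helper name build_sequence are illustrative assumptions, not the actual CogView implementation.

```python
# Minimal sketch: build one token sequence for text-to-image language modeling.
# Separator IDs and token values are hypothetical placeholders, not CogView's real vocabulary.
import torch

ROI1, BASE, BOI1, EOI1 = 0, 1, 2, 3          # assumed IDs for the special separator tokens

def build_sequence(text_tokens: torch.Tensor, image_tokens: torch.Tensor) -> torch.Tensor:
    """Concatenate [ROI1] text [BASE] [BOI1] image [EOI1] into one 1-D LongTensor."""
    sep = lambda t: torch.tensor([t], dtype=torch.long)
    return torch.cat([
        sep(ROI1), text_tokens,               # text part
        sep(BASE), sep(BOI1),
        image_tokens.flatten(),               # 32x32 grid of image tokens, flattened row-major
        sep(EOI1),
    ])

# Example with dummy tokens: a GPT-style model is trained to predict every token
# from its left context with a plain cross-entropy loss.
text = torch.randint(10, 100, (12,))          # pretend SentencePiece output
image = torch.randint(0, 8192, (32, 32))      # pretend discrete-autoencoder codes
seq = build_sequence(text, image)
print(seq.shape)                              # torch.Size([1040]) = 12 text + 4 separators + 1024 image tokens
```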
6. Theory
CogView essentially optimizes the ELBO of the joint likelihood of images and text:
• Vanilla VAE: fixes the prior (usually a standard Gaussian) and optimizes the KL term through the encoder's approximate posterior.
• VQ-VAE: the approximate posterior is only used for the reconstruction term and is held fixed in the KL term; the KL term is instead optimized by learning the prior (an autoregressive model).
After discretization, we further obtain:
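The following is a hedged reconstruction of the two formulas these bullets describe, under the standard conditional-VAE decomposition; the notation (image x, text t, latent image tokens z, encoder q, decoder p(x|z; ψ), autoregressive prior p(z|t; θ)) is assumed here and may differ from the paper.

```latex
% ELBO of the joint likelihood of image x and text t (assumed notation)
\log p(x,t;\theta,\psi)\;\ge\;
\underbrace{\mathbb{E}_{z\sim q(z\mid x;\phi)}\big[\log p(x\mid z;\psi)\big]}_{\text{reconstruction}}
\;-\;\mathrm{KL}\big(q(z\mid x;\phi)\,\big\|\,p(z\mid t;\theta)\big)
\;+\;\log p(t;\theta).

% After discretization, q becomes a deterministic (one-hot) posterior with zero entropy,
% so the KL term reduces to the negative log-likelihood of the image tokens under the prior:
\log p(x,t;\theta,\psi)\;\ge\;
\underbrace{\log p(x\mid z;\psi)}_{\text{tokenizer reconstruction}}
\;+\;\underbrace{\log p(z\mid t;\theta)+\log p(t;\theta)}_{\text{likelihood of text and image tokens}},
\qquad z = \text{the discrete encoding of } x.
```

Under this reading, the first term is trained in stage 1 (the image tokenizer), while the second term is exactly the left-to-right cross-entropy over the concatenated text-and-image token sequence trained in stage 2 (GPT).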
7. Samples
Captions of the displayed samples:
• A blonde is on the phone.
• A Big Ben over London.
• Super-resolution: mid-lake pavilion.
• A couple in leather motorcycle suits are riding a motorcycle.
• A coffee cup printed with a cat. Sky background.
• A tiger is playing football.
• A man is flying to the moon on his bicycle.
• Chinese traditional drawing. Statue of Liberty.
• Oil painting. Lion.
• Cartoon. A tiger is playing football.
• Sketch. Houses.
8. CogView Generation
"孤独的氛围" (a lonely atmosphere)
"冬日暖阳里的乌鸦" (a crow in the warm winter sun)
9. Transformer is just the beginning…
Challenges:
• (Tokenization) What is the best way to train the VQVAE (stage 1)?
• Straight-through? Gumbel-softmax? Moving average? Fixed codebook?
• (Instability) We quickly encounter value explosion (NaN losses).
• FP16 + dynamic loss scaler.
• We did not observe this phenomenon in ordinary NLP pretraining.
• We did not observe this phenomenon when pretraining small (< 1B parameter) models.
• (Resolution) Unaffordable expense (time and memory) for higher resolution.
• We use 32*32 tokens; what about 64*64?
• (Post-selection) How to select the best one from many generated samples?
10. Comparison of different training methods for VQVAE
• The training of VQVAE is robust enough.*
• All training methods will converge to a similar loss level, even if the embeddings in the codebook are fixed after initialization.
*Our results are quite different from an (informal) conclusion by Andrej Karpathy at https://github.com/karpathy/deep-vector-quantization, who reports that the training exhibits catastrophic index collapse.
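For reference, below is a minimal sketch of the straight-through variant among the compared training methods; the module, codebook size, and commitment weight are illustrative assumptions, not CogView's image tokenizer.

```python
# Minimal sketch of straight-through vector quantization (one of the compared variants).
# Hypothetical module, not the actual CogView image tokenizer.
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_codes: int = 8192, dim: int = 256, beta: float = 0.25):
        super().__init__()
        self.codebook = nn.Embedding(num_codes, dim)    # could also be frozen ("fixed codebook")
        self.beta = beta

    def forward(self, z_e: torch.Tensor):
        # z_e: encoder output of shape (batch, height*width, dim)
        flat = z_e.reshape(-1, z_e.size(-1))            # (batch*HW, dim)
        dist = torch.cdist(flat, self.codebook.weight)  # distances to every code vector
        idx = dist.argmin(dim=-1).view(z_e.shape[:-1])  # nearest-code indices = discrete tokens
        z_q = self.codebook(idx)                        # quantized vectors, same shape as z_e

        # Codebook and commitment losses (the VQ-VAE objective terms besides reconstruction).
        loss = F.mse_loss(z_q, z_e.detach()) + self.beta * F.mse_loss(z_e, z_q.detach())

        # Straight-through estimator: forward pass uses z_q, backward copies gradients to z_e.
        z_q = z_e + (z_q - z_e).detach()
        return z_q, idx, loss
```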
11. The Instability of Pretraining: Phenomenon and Reasons
Direct reason:
• Values in the hidden representations are magnified layer by layer.
• Overflow in attention or LayerNorm (illustrated below).
• Happens with FP16.
• The NaN loss results from value explosion.
• It reduces the loss scaler, which then leads to incorrect updates due to the underflow of gradients.
Root causes, relevant to:
• the precision.
• the heterogeneity of the data.
• the model size (smaller models merely hide the problem).
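A toy illustration of the overflow (the numbers are made up for the example, not taken from the actual runs): FP16 cannot represent values above roughly 6.5e4, so moderately large hidden values overflow as soon as they are multiplied together in the attention scores.

```python
# Toy illustration of the FP16 overflow behind the value explosion (made-up numbers).
import torch

fp16_max = torch.finfo(torch.float16).max       # 65504.0, the largest finite FP16 value
score = torch.tensor(300.0) * 300.0             # one large attention logit, computed in FP32
print(score.item() > fp16_max)                  # True: 90000 exceeds the FP16 range
print(score.to(torch.float16))                  # tensor(inf, dtype=torch.float16); inf then turns into NaN in softmax/LayerNorm
```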
Still an open problem?
At least we find the solution…
12. Instability in Pretraining: Solutions
• Sandwich LayerNorm
• Block up the layer-wise
magnification.
• Precision-Bottleneck relaxation
• Make the computation of attention
and LayerNorm precision-friendly.
DALL-E uses another solution: per-resblock loss scaling and casting some parameters back to FP32.
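Below is a minimal sketch of the two stabilization ideas, written against a generic transformer block; the scaling constant alpha, the module layout, and the helper names are assumptions for illustration, not the exact CogView code.

```python
# Minimal sketch of Sandwich LayerNorm and PB-relaxed attention scores.
# Generic illustration; alpha and the block layout are assumed, not CogView's exact code.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SandwichBlock(nn.Module):
    """Residual branch wrapped by LayerNorms on BOTH sides: x + LN(f(LN(x)))."""
    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.ln_in, self.ln_out = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        # The extra output LayerNorm blocks the layer-by-layer magnification of values.
        return x + self.ln_out(self.sublayer(self.ln_in(x)))

def pb_relaxed_scores(q, k, alpha: float = 32.0):
    """Compute softmax(QK^T / sqrt(d)) in a precision-friendly way.

    QK^T is first computed at a 1/alpha scale, the per-row max is subtracted
    (which leaves the softmax unchanged), and only then is the result rescaled,
    so the intermediate values stay far below the FP16 limit.
    """
    d = q.size(-1)
    scores = (q / (alpha * d ** 0.5)) @ k.transpose(-2, -1)   # small-magnitude scores
    scores = scores - scores.amax(dim=-1, keepdim=True)       # shift by the row max
    return F.softmax(alpha * scores, dim=-1)                  # == softmax(QK^T / sqrt(d))
```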
13. Local Attention in CogView1
All texts and some random “pivot” attention
+ Blockwise window attention
• To approximate full attention.
• Local? Image-to-text attention is non-local.
• 3 parts:
• Local (Longformer)
• Random (BigBird)
• Fixed text positions
• 2.5× faster, saves 40% of memory (window size 256).
• The same performance.
• Hard to implement? No.
14. Pure PyTorch Local Attention in CogView
• Use manual strides to build a non-overlapping view, as sketched below.
• But this does not work in the backward pass.
• A customized kernel is still needed to save memory (used in CogView2).
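A minimal pure-PyTorch sketch of the windowed part of the scheme; the shapes, window size, and the helper name blockwise_local_attention are illustrative assumptions, and this is not the actual CogView kernel (which also mixes in text and random "pivot" attention).

```python
# Minimal sketch: non-overlapping blockwise windows for local attention in pure PyTorch.
# Shapes and window size are illustrative; not the actual CogView implementation.
import torch
import torch.nn.functional as F

def blockwise_local_attention(q, k, v, window: int = 256):
    """q, k, v: (batch, seq_len, dim) with seq_len divisible by `window`.

    Each query attends only to the keys inside its own window, so the score
    matrix is (num_blocks, window, window) instead of (seq_len, seq_len).
    """
    b, n, d = q.shape
    nb = n // window
    # Reshape into non-overlapping blocks; because the blocks do not overlap,
    # a plain view is enough (a strided view via unfold gives the same layout).
    qb = q.view(b, nb, window, d)
    kb = k.view(b, nb, window, d)
    vb = v.view(b, nb, window, d)

    scores = qb @ kb.transpose(-2, -1) / d ** 0.5          # (b, nb, window, window)
    out = F.softmax(scores, dim=-1) @ vb                   # attend within each block
    return out.view(b, n, d)

# Example: 1024 image tokens with window 256 -> 4 independent blocks.
x = torch.randn(2, 1024, 64)
print(blockwise_local_attention(x, x, x).shape)            # torch.Size([2, 1024, 64])
```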
15. Towards higher resolution
– Do we really need to pretrain with a long sequence?
(Illustration: "The Big Ben" vs. "The close-up view of the Big Ben".)
• Since CogView is trained on the most complex distribution of domain-general images, the details of objects have already been covered.
• Finetuning it into a super-resolution model should not be hard. (16*16 tokens => 32*32 tokens, 1 DGX*day)
16. How to do post-selection?
• Baseline: we can train another text-image similarity model based on
contrastive learning, for example CLIP.
• However, we have already trained a cross-modal self-supervised model,
CogView!
• Knowledge can be reused.
• Just finetune CogView in the inverse order (image-to-text).
• 1 DGX*day
• Evaluate the similarity by P(text | image), as sketched below.
• Inverse prompting [Xu Zou et al., KDD 2021]
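A minimal sketch of scoring candidates by P(text | image) with the inverse (image-to-text) model; the model interface and token layout are assumed for illustration and are not the paper's exact implementation.

```python
# Minimal sketch of self-reranking: score each generated image by the average
# log-likelihood of the caption under the image-to-text finetuned model.
# `model` is a hypothetical GPT-style wrapper returning (1, L, vocab) logits.
import torch
import torch.nn.functional as F

@torch.no_grad()
def caption_score(model, image_tokens, text_tokens):
    """Return (1/|t|) * sum_i log P(t_i | image, t_<i)."""
    seq = torch.cat([image_tokens, text_tokens])              # image first, then text
    logits = model(seq[:-1].unsqueeze(0)).squeeze(0)          # next-token logits per position
    text_logits = logits[len(image_tokens) - 1:]              # positions that predict text tokens
    logp = F.log_softmax(text_logits, dim=-1)
    token_logp = logp.gather(1, text_tokens.unsqueeze(1)).squeeze(1)
    return token_logp.mean().item()

def rerank(model, candidates, text_tokens, top_k: int = 1):
    """Sort candidate image-token sequences by P(text | image), best first."""
    scored = sorted(candidates,
                    key=lambda img: caption_score(model, img, text_tokens),
                    reverse=True)
    return scored[:top_k]
```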
17. Experiments: Zero-shot text-to-image generation on MS COCO
• Outperforms DALL-E even without super-resolution.
18. Experiments: Post-selection
• Our self-reranking achieves a better (lower) FID than CLIP when reranking images generated by CogView.
• An example of reranking 60 images of "a man wearing red shirt is playing video game" (bad cases included).
19. Experiments: Human Evaluation
CogView is very competitive with the ground truth recovered from the VQVAE (37% vs. 59% of votes).
20.
21. Limitations
• Slow generation
• A drawback of all auto-regressive models.
• CogView is large.
• A sliding window for 64*64 tokens takes at least 4× the time.
• Blurriness and artifacts
• Due to vector quantization.
• Correspondence between text and image
• All three limitations are improved upon in CogView2.
22. Conclusion
• Text-image GPT
• Stabilization: Sandwich-LN & PB-relax
• Finetuning: high-resolution generation, self-reranking
• More about CogView (see the paper): style finetuning, fashion design, three-region sparse attention
Codes and Models: https://github.com/THUDM/CogView
Demo website (latest version): http://wudao.aminer.cn/CogView/index.html
The logo and the illustration images in these slides are all generated by CogView2.