2025-09-02
391 words • ~2 min
Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion
- introduces a novel training paradigm for autoregressive video diffusion models
- addresses exposure bias: models trained on ground-truth context must, at inference time, generate sequences conditioned on their own imperfect predictions
- Diffusion Transformers (DiTs) to date denoise all frames simultaneously using bidirectional attention (future frames can affect past frames, and the entire video must be generated at once) - the authors argue this fundamentally limits their applicability for real-time applications
- autoregressive models could help here, but they usually struggle to match state-of-the-art video diffusion quality
- teacher forcing recap: predict the next token conditioned on ground-truth tokens. in the context of video diffusion, TF trains the model to predict each frame conditioned on ground-truth context frames - but at inference the model conditions on its own generations, so the training factorization does not match the joint distribution it actually samples $$p(\hat{x}^1)p(\hat{x}^2|x^1)p(\hat{x}^3|x^1,x^2)\cdots p(\hat{x}^T|x^{\lt T}) \neq p(\hat{x}^1, \hat{x}^2, \ldots, \hat{x}^T)$$
- diffusion forcing recap: trains the model on videos with noise levels sampled independently for each frame (denoising each frame given past noisy context frames) - this covers the autoregressive inference scenario where context frames are clean and the current frame is noisy, but training still conditions on (noisy) ground-truth frames rather than the model's own generations, so the mismatch remains $$p(\hat{x}^1)p(\hat{x}^2|x_{t^1}^1)p(\hat{x}^3|x_{t^1}^1, x_{t^2}^2)\cdots p(\hat{x}^T|x_{t^{\lt T}}^{\lt T}) \neq p(\hat{x}^1, \hat{x}^2, \ldots, \hat{x}^T)$$
- introduces "self forcing": addresses exposure bias. inspired by RNN-era sequence modeling techniques to bridge train-test distribution gap by explicitly unrolling autoregressive generation during training.
- each frame is conditioned on previously self-generated frames rather than ground truth frames
- supervision with distribution-matching losses $$p(\hat{x}^1)p(\hat{x}^2|\hat{x}^1)p(\hat{x}^3|\hat{x}^1,\hat{x}^2)...p(\hat{x}^T|\hat{x}^{\lt T}) = p(\hat{x}^1, \hat{x}^2, ..., \hat{x}^{T-1})$$
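a minimal sketch contrasting what each scheme conditions on during training, assuming a hypothetical `denoise(x_noisy, t, context)` callable and simple linear noising (stand-ins, not the paper's implementation):

```python
import torch

def rollout(denoise, x_gt, scheme, num_steps=4):
    """Generate T frames autoregressively; `scheme` picks the conditioning.

    denoise(x_noisy, t, context) -> less noisy frame  (hypothetical signature)
    x_gt: ground-truth frames (T, C, H, W); only TF/DF actually use them
    """
    T = x_gt.shape[0]
    generated = []
    for i in range(T):
        if scheme == "teacher_forcing":
            # clean ground-truth history: a context never seen at inference
            context = x_gt[:i]
        elif scheme == "diffusion_forcing":
            # ground-truth history with an independent noise level per frame
            t_ctx = torch.rand(i).view(-1, 1, 1, 1)
            context = (1 - t_ctx) * x_gt[:i] + t_ctx * torch.randn_like(x_gt[:i])
        else:  # "self_forcing"
            # the model's own previous outputs: exactly the inference distribution
            context = torch.stack(generated) if generated else x_gt[:0]
        # few-step denoising of the current frame from pure noise
        x = torch.randn_like(x_gt[0])
        for step in range(num_steps):
            t = torch.tensor(1.0 - step / num_steps)
            x = denoise(x, t, context)
        generated.append(x)
    return torch.stack(generated)
```

with teacher/diffusion forcing the context comes from `x_gt`, which the model never sees at inference; with self forcing the context is exactly what inference would produce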
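and a minimal sketch of a DMD-style distribution-matching loss on the unrolled video - `real_score` (teacher for the data distribution) and `fake_score` (critic tracking the generator's current distribution) are hypothetical denoiser/score networks, and the critic's own denoising-loss updates on generator samples are omitted; the papers in the todo list below cover the DMD, SiD, and GAN variants:

```python
import torch

def dmd_style_loss(x_gen, real_score, fake_score):
    # diffuse the self-generated video x_gen to a random noise level
    t = torch.rand(())
    x_t = (1 - t) * x_gen + t * torch.randn_like(x_gen)
    with torch.no_grad():
        # score difference ~ gradient of the reverse KL(generator || data)
        grad = fake_score(x_t, t) - real_score(x_t, t)
    # standard DMD surrogate: its gradient w.r.t. x_gen equals grad (up to scale),
    # so backprop nudges the whole rollout toward the data distribution
    return 0.5 * ((x_gen - (x_gen - grad).detach()) ** 2).mean()
```

backpropagating `dmd_style_loss(rollout(generator, x_gt, "self_forcing"), ...)` through the unrolled generation is, roughly, the training step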
todo reading and re-reading:
- Improved Distribution Matching Distillation for Fast Image Synthesis
- One-step Diffusion with Distribution Matching Distillation
- Score Identity Distillation
- Generative Adversarial Networks
reading list for exposure bias - some approaches attempt to mitigate the distributional mismatch by incorporating noisy context frames during inference:
- Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion
- Oasis: A Universe in a Transformer
- Packing Input Frame Context in Next-Frame Prediction Models for Video Generation
reading list for distribution-matching losses: see the todo list above (DMD, DMD2, SiD, GAN objectives)