Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes
arXiv preprint
Read on: May 15, 2024
View Original Paper: Distilling Step-by-Step
1. Core Idea
This paper introduces Distilling Step-by-Step, a training framework that uses rationales generated by a large language model (PaLM-540B) as additional supervision for training smaller models (T5). The result is better performance with smaller model sizes and significantly fewer training examples.
2. Key Contributions
- Proposed Distilling Step-by-Step, a novel multi-task framework that uses rationales from a large model (PaLM-540B) to train smaller models (T5).
- Demonstrated significant improvements on popular NLP benchmarks (e-SNLI, ANLI, CommonsenseQA, and SVAMP) compared to standard finetuning and distillation.
- Showed that the method drastically reduces the required model size and amount of training data; for example, a 770M T5 model outperforms the 540B PaLM teacher on ANLI while using only 80% of the labeled data.
- Provided comprehensive ablation studies to validate the impact of rationale quality and multi-task training strategies.
3. Method Summary
The method involves two steps:
- Generating Rationales: Given an unlabeled dataset, the large teacher model (PaLM-540B) is prompted with few-shot Chain-of-Thought (CoT) examples to generate two outputs for each input (see the prompting sketch after this list):
  - a rationale r̂ᵢ (the intermediate reasoning steps)
  - a predicted label ŷᵢ
- Multi-Task Training: A smaller student model (T5-Base, T5-Large, or T5-XXL) is then trained with a combined loss function (a minimal training sketch also follows this list):
  Total Loss = Label Loss + λ × Rationale Loss
  where:
  - the Label Loss is the cross-entropy loss that trains the model to predict the correct label, and
  - the Rationale Loss is the cross-entropy loss that trains the model to also generate the reasoning steps.
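Below is a minimal sketch of the rationale-generation step, under stated assumptions: `call_teacher_llm` is a hypothetical placeholder for whatever large-model API is available (the paper used PaLM-540B), and the few-shot prompt plus the "So the answer is" parsing convention are illustrative, not the paper's exact template.

```python
# Step 1 (sketch): elicit a rationale and a label from a large teacher model
# via few-shot Chain-of-Thought (CoT) prompting.

FEW_SHOT_COT_EXAMPLES = """\
Q: Can a goldfish drive a car?
A: A goldfish is a small animal without limbs and cannot operate machinery. So the answer is no.
"""

def call_teacher_llm(prompt: str) -> str:
    """Placeholder: send `prompt` to a large teacher LLM (e.g. PaLM-540B) and
    return its completion. Swap in whatever API is actually available."""
    raise NotImplementedError

def generate_rationale_and_label(question: str) -> tuple[str, str]:
    # Few-shot CoT prompt: worked examples followed by the new question.
    prompt = f"{FEW_SHOT_COT_EXAMPLES}\nQ: {question}\nA:"
    completion = call_teacher_llm(prompt).strip()
    # The CoT-style answer is expected to end with "So the answer is <label>.";
    # split it into the rationale (reasoning steps) and the predicted label.
    rationale, _, label = completion.rpartition("So the answer is")
    return rationale.strip(), label.strip(" .")
```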
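And a minimal sketch of the multi-task training objective using Hugging Face Transformers (which the released code is based on); the `[label]`/`[rationale]` task prefixes and the λ value here are illustrative assumptions, not the repository's exact configuration.

```python
# Step 2 (sketch): multi-task student training with the combined loss
# Total Loss = Label Loss + λ × Rationale Loss.
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
lam = 0.5  # weight on the rationale loss (hyperparameter; value is an assumption)

def step_by_step_loss(input_text: str, label: str, rationale: str) -> torch.Tensor:
    def seq2seq_loss(task_prefix: str, target: str) -> torch.Tensor:
        enc = tokenizer(task_prefix + input_text, return_tensors="pt", truncation=True)
        tgt = tokenizer(target, return_tensors="pt", truncation=True)
        # T5ForConditionalGeneration computes the cross-entropy loss internally
        # when `labels` is supplied.
        return model(**enc, labels=tgt.input_ids).loss

    label_loss = seq2seq_loss("[label] ", label)              # predict the label
    rationale_loss = seq2seq_loss("[rationale] ", rationale)  # generate the reasoning
    return label_loss + lam * rationale_loss
```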
4. Experiments & Results
Datasets used:
- e-SNLI (Natural Language Inference)
- ANLI (Adversarial NLI)
- CommonsenseQA (CQA, Question Answering)
- SVAMP (Arithmetic Math Problems)
Key Metrics:
- Accuracy on test sets
Main Results:
- Distilling Step-by-Step outperformed baseline methods (Standard Finetuning, Standard Distillation, Few-shot CoT PaLM-540B) on all benchmarks.
- Achieved comparable or better results using up to 2000× smaller models.
- On ANLI, a 770M T5-Large student outperformed the 540B PaLM teacher while using only 80% of the labeled data, whereas standard finetuning of the same model failed to match PaLM even with the full dataset.
5. Limitations / Weaknesses
- Assumes high-quality rationales from large models (PaLM-540B); performance may degrade if rationale quality is lower (as shown with GPT-NeoX 20B).
- Requires few-shot CoT prompting examples, which introduces manual overhead.
- Computational overhead at training due to multi-task learning (though inference remains efficient).
6. Reproducibility / Code
- Implementation is provided openly on GitHub:
- https://github.com/google-research/distilling-step-by-step
- Code is clear and built on the Hugging Face Transformers library.
- Requires GPU (A100 recommended) and common CUDA environments.
7. Follow-up Ideas
- Investigate rationale quality systematically:
- Rationale Filtering: Improve student performance by selecting the best rationale from multiple LLM-generated options using ensemble voting or GPT-4-based scoring (see the sketch after this list).
- Zero-Shot Rationale Distillation: Replace few-shot CoT prompting with zero-shot CoT to eliminate the need for handcrafted examples, testing scalability across tasks.
- Bootstrapped Self-Training: After initial training, let the student model generate rationales for new unlabeled data and use them to continue self-training without the teacher.
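A hypothetical sketch of the rationale-filtering idea, reusing the `generate_rationale_and_label` placeholder from section 3 and assuming the teacher is sampled with nonzero temperature so repeated calls differ:

```python
from collections import Counter

def filter_rationales(question: str, n_samples: int = 5) -> tuple[str, list[str]]:
    # Sample several rationale/label pairs from the teacher, then
    # majority-vote the predicted label across samples.
    samples = [generate_rationale_and_label(question) for _ in range(n_samples)]
    majority_label, _ = Counter(label for _, label in samples).most_common(1)[0]
    # Keep only rationales whose predicted label agrees with the consensus.
    kept = [rationale for rationale, label in samples if label == majority_label]
    return majority_label, kept
```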