
Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes

Hsieh, C.-Y., Li, C.-L., Yeh, C.-K., Nakhost, H., Fujii, Y., Ratner, A., Krishna, R., Lee, C.-Y., & Pfister, T.

arXiv preprint

Read on: May 15, 2024


Distilling Step-by-Step

1. Core Idea

This paper introduces Distilling Step-by-Step, a training framework that uses rationales generated by a large language model (PaLM-540B) as additional supervision for training smaller task-specific models (T5). The approach outperforms standard finetuning and distillation while using smaller models and significantly fewer training examples.


2. Key Contributions

  • Proposed Distilling Step-by-Step, a novel multi-task framework that uses rationales from a large model (PaLM-540B) to train smaller models (T5).
  • Demonstrated significant improvements on popular NLP benchmarks (e-SNLI, ANLI, CommonsenseQA, and SVAMP) compared to standard finetuning and distillation.
  • Showed that the method drastically reduces the required model size and training data: a 770M T5 model outperforms the 540B PaLM model on ANLI while using only 80% of the available data.
  • Provided comprehensive ablation studies to validate the impact of rationale quality and multi-task training strategies.

3. Method Summary

The method involves two steps:

  1. Generating Rationales: Given an unlabeled dataset, the large model (PaLM-540B) is prompted using Chain-of-Thought (CoT) examples to generate two outputs:
    • A rationale r̂ᵢ (the step-by-step reasoning)
    • A predicted label ŷᵢ
  2. Multi-Task Training: A smaller model (T5-Base, T5-Large, or T5-XXL) is then trained using a combined loss function:

    Total Loss = Label Loss + λ × Rationale Loss

where:

  • The Label Loss is the standard cross-entropy loss for predicting the correct label ŷᵢ.
  • The Rationale Loss is the cross-entropy loss for generating the teacher's rationale r̂ᵢ as a separate output, so the student learns the reasoning without needing a rationale as input at inference time.
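
To make the training recipe concrete, here is a minimal sketch of the multi-task objective, assuming a Hugging Face T5 student and PyTorch. The [label]/[rationale] task prefixes, the λ value, and the toy example are illustrative choices, not necessarily those of the official implementation.

```python
# Minimal sketch (not the official implementation) of the multi-task objective:
# the student is trained to predict the teacher's label and, as a separate task,
# to generate the teacher's rationale.
# Assumes: pip install torch transformers sentencepiece
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
student = T5ForConditionalGeneration.from_pretrained("t5-base")
optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)

RATIONALE_WEIGHT = 1.0  # the lambda in: Total Loss = Label Loss + lambda * Rationale Loss


def seq2seq_loss(input_text: str, target_text: str) -> torch.Tensor:
    """Cross-entropy loss of the student on one (input, target) pair."""
    enc = tokenizer(input_text, return_tensors="pt", truncation=True)
    labels = tokenizer(target_text, return_tensors="pt", truncation=True).input_ids
    return student(**enc, labels=labels).loss


def training_step(question: str, teacher_label: str, teacher_rationale: str) -> float:
    """One multi-task step on a single example (batching omitted for clarity)."""
    # Task 1: label prediction (the "[label]" prefix is an illustrative convention).
    label_loss = seq2seq_loss(f"[label] {question}", teacher_label)
    # Task 2: rationale generation (the "[rationale]" prefix is likewise illustrative).
    rationale_loss = seq2seq_loss(f"[rationale] {question}", teacher_rationale)

    total = label_loss + RATIONALE_WEIGHT * rationale_loss
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    return total.item()


# Example usage with a teacher output (rationale r-hat, label y-hat) obtained by
# few-shot CoT prompting of a large LLM on an unlabeled question:
question = "A pencil costs 2 dollars and a pen costs 3 dollars. How much do both cost?"
teacher_rationale = "The pencil costs 2 and the pen costs 3, so together they cost 2 + 3 = 5."
teacher_label = "5"
training_step(question, teacher_label, teacher_rationale)
```

At inference time only the label task is queried, so serving cost matches a normally finetuned student of the same size.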

4. Experiments & Results

Datasets used:

  • e-SNLI (Natural Language Inference)
  • ANLI (Adversarial NLI)
  • CommonsenseQA (CQA, Question Answering)
  • SVAMP (Arithmetic Math Problems)

Key Metrics:

  • Accuracy on test sets

Main Results:

  • Distilling Step-by-Step consistently outperformed standard finetuning and standard task distillation on all four benchmarks, and matched or exceeded few-shot CoT prompting of PaLM-540B with far smaller student models.
  • Achieved comparable or better results using up to 2000× smaller models.
  • On ANLI, a 770M T5-Large student surpassed PaLM-540B while using only 80% of the training data, whereas standard finetuning failed to match PaLM even with the full dataset.

5. Limitations / Weaknesses

  • Assumes high-quality rationales from large models (PaLM-540B); performance may degrade if rationale quality is lower (as shown with GPT-NeoX 20B).
  • Requires few-shot CoT prompting examples, which introduces manual overhead.
  • Extra training-time computation due to the multi-task objective (inference remains efficient, since the student does not need to generate rationales at test time).

6. Reproducibility / Code

  • Implementation is provided openly on GitHub:
    • https://github.com/google-research/distilling-step-by-step
  • Code is clear and built on the Hugging Face Transformers library.
  • Requires a GPU (A100 recommended) and a standard CUDA environment.

7. Follow-up Ideas

  • Investigate rationale quality systematically:
    • Rationale Filtering: Improve student performance by selecting the best rationale from several LLM-generated candidates, e.g. via self-consistency voting or GPT-4-based scoring (a minimal sketch follows after this list).
    • Zero-Shot Rationale Distillation: Replace few-shot CoT prompting with zero-shot CoT to eliminate the need for handcrafted examples, testing scalability across tasks.
    • Bootstrapped Self-Training: After initial training, let the student model generate rationales for new unlabeled data and use them to continue self-training without the teacher.
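
As a starting point for the rationale-filtering idea, here is a hedged sketch of self-consistency-style filtering: sample several (rationale, label) pairs from the teacher and keep only the rationales whose label agrees with the majority vote. The function `sample_rationale_and_label` is a hypothetical placeholder for whatever teacher-LLM call is used.

```python
# Hypothetical sketch of rationale filtering via self-consistency voting:
# sample several (rationale, label) pairs from the teacher LLM, then keep only
# the rationales whose predicted label matches the majority answer.
# `sample_rationale_and_label` is a placeholder for the actual teacher call.
from collections import Counter
from typing import Callable, List, Tuple


def filter_rationales(
    question: str,
    sample_rationale_and_label: Callable[[str], Tuple[str, str]],
    num_samples: int = 8,
) -> Tuple[str, List[str]]:
    """Return the majority-vote label and the rationales that support it."""
    samples = [sample_rationale_and_label(question) for _ in range(num_samples)]
    labels = [label for _, label in samples]
    majority_label, _ = Counter(labels).most_common(1)[0]
    kept_rationales = [r for r, label in samples if label == majority_label]
    return majority_label, kept_rationales
```

The filtered (question, majority label, rationale) triples could then feed the same multi-task training step sketched in Section 3.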