Using teacher knowledge at inference time to enhance student models
New method improves the state of the art in knowledge distillation by leveraging a knowledge base of teacher predictions.
Knowledge distillation (KD) is one of the most effective ways to deploy large-scale language models in environments where low latency is essential. KD involves transferring the knowledge contained in large-scale models (“teachers”) to smaller models (“students”).
Because of their smaller size, student models are typically more efficient than teacher models, but they're often less powerful. In a paper we presented at this year's meeting of the Association for Computational Linguistics (ACL), we proposed retrieval-augmented knowledge distillation (ReAugKD), a framework that leverages the power of teacher models to improve student models' performance with minimal latency overhead.
Specifically, we use data representations (embeddings) and predictions produced by the teacher model on previous inputs — which can be stored in a lookup table — to guide the student model’s predictions for similar inputs. In principle, however, the same approach could be adapted for any task-specific external knowledge.
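To make the lookup concrete, here is a minimal sketch of inference-time retrieval in Python. It assumes the knowledge base is a small in-memory list of resized teacher embeddings paired with teacher class probabilities, that retrieval uses cosine similarity, that the top-k teacher predictions are weighted by a softmax over their similarity scores, and that the result is blended 50/50 with the student's own prediction. The function names and the `k` and `mix` parameters are hypothetical; the aggregation scheme is an illustrative assumption, not a description of how ReAugKD combines retrieved predictions.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def retrieve_and_predict(student_emb, student_pred, kb_embs, kb_preds, k=2, mix=0.5):
    """Hypothetical inference step: find the k stored teacher embeddings most
    similar to the student's embedding and blend their teacher predictions with
    the student's own class probabilities."""
    # Score every stored teacher embedding against the student's query embedding
    # (a linear scan keeps the sketch self-contained).
    scores = [cosine(student_emb, e) for e in kb_embs]
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    # Softmax over the top-k scores turns them into weights that sum to 1.
    exps = [math.exp(scores[i]) for i in top]
    total = sum(exps)
    weights = [x / total for x in exps]
    # Weighted average of the retrieved teacher predictions, class by class.
    retrieved = [
        sum(w * kb_preds[i][c] for w, i in zip(weights, top))
        for c in range(len(student_pred))
    ]
    # Assumed fixed mixing weight between retrieved teacher knowledge and the student.
    return [mix * r + (1 - mix) * s for r, s in zip(retrieved, student_pred)]


# Toy knowledge base: resized teacher embeddings paired with teacher class probabilities.
kb_embs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
kb_preds = [[0.9, 0.1], [0.8, 0.2], [0.1, 0.9]]
print(retrieve_and_predict([1.0, 0.05], [0.6, 0.4], kb_embs, kb_preds))  # ≈ [0.725, 0.275]
```

The student alone leans only mildly toward class 0 (0.6), but its two nearest teacher entries agree strongly, so the blended prediction moves toward the teacher's view. A production system would more likely use an approximate-nearest-neighbor index than a linear scan.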
To evaluate ReAugKD, we compared its performance to that of ten prior models on six natural-language-processing tasks, including paraphrasing, natural-language inference, and question answering. On five of the tasks, ReAugKD was the top performer, and on the sixth, it ranked second. On average, it establishes a new state of the art for the benchmark, while incurring a latency overhead of only 3%.
ReAugKD uses a two-step training procedure. In the first step, we begin with a teacher model that has been fine-tuned for a specific downstream task. Then we add a linear-projection layer on top of the model's encoder to project the encoder's embeddings (vector representations of the input data) to the same dimension as the student model's encoder. To fine-tune the parameters of the linear-projection layer, we use a supervised contrastive loss, which treats training examples with the same label as positives and contrasts them with negatives sampled randomly from the remainder of the batch.
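The first-step objective can be sketched in plain Python as follows. There is no autodiff here, so the sketch only evaluates the loss for a given projection rather than training it. As a simplification, every other example in the batch appears in the softmax denominator (the standard supervised-contrastive formulation) instead of randomly sampled negatives; the temperature `tau`, the projection matrix `W`, and all function names are illustrative assumptions.

```python
import math


def project(W, x):
    """Linear projection without bias: W is d_s x d_t, x has length d_t."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def normalize(v):
    """Scale a vector to unit length so dot products become cosine similarities."""
    n = math.sqrt(sum(a * a for a in v))
    return [a / n for a in v]


def sup_con_loss(embs, labels, tau=0.1):
    """Supervised contrastive loss for one batch: for each anchor, same-label
    examples are positives and every other example in the batch appears in the
    softmax denominator."""
    z = [normalize(e) for e in embs]
    n = len(z)
    total, anchors = 0.0, 0
    for i in range(n):
        positives = [p for p in range(n) if p != i and labels[p] == labels[i]]
        if not positives:
            continue  # an anchor with no same-label partner contributes nothing
        # Temperature-scaled similarities from anchor i to every other example.
        logits = {j: sum(a * b for a, b in zip(z[i], z[j])) / tau for j in range(n) if j != i}
        log_denom = math.log(sum(math.exp(v) for v in logits.values()))
        # Average negative log-probability assigned to the anchor's positives.
        total += -sum(logits[p] - log_denom for p in positives) / len(positives)
        anchors += 1
    return total / anchors if anchors else 0.0


W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # hypothetical 2x3 projection: teacher dim 3 -> student dim 2
teacher = [[2.0, 0.0, 1.0], [1.8, 0.2, 0.0], [0.0, 2.0, 1.0], [0.2, 1.8, 0.0]]
projected = [project(W, t) for t in teacher]
print(sup_con_loss(projected, [0, 0, 1, 1]))  # labels follow the clusters: small loss
print(sup_con_loss(projected, [0, 1, 0, 1]))  # labels cut across the clusters: large loss
```

Projected embeddings that cluster by label score a much lower loss than the same embeddings with labels that cut across the clusters; minimizing this loss is what pulls the projected teacher embeddings of same-label inputs together.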
In the second step, we generate (resized) teacher embeddings and teacher predictions for the input data we’ll use to train the student. Then we create a similarity matrix for the teacher embeddings, which measures the similarity between the embedding of each input and those of all the other inputs.
To train the student model, we create a similarity matrix between the student embeddings and the teacher embeddings and use a loss function that minimizes the Kullback–Leibler divergence between the teacher-teacher similarity distribution and the teacher-student similarity distribution. Essentially, this ensures that at inference time, when we search our knowledge base for teacher embeddings similar to the student's embedding of the current input, the student and the teacher use the same notion of similarity.
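The two similarity distributions and the KL term of the second step can be sketched as below. The sketch assumes cosine similarity, a row-wise softmax with a temperature `tau`, and KL(teacher-teacher || teacher-student) averaged over the batch, with each student embedding compared against all teacher embeddings to mirror the inference-time query. These choices and the function names are assumptions for illustration rather than the paper's exact formulation.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def softmax(xs):
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def similarity_distributions(queries, keys, tau=1.0):
    """Row i is a softmax over the similarities between queries[i] and every key."""
    return [softmax([cosine(q, k) / tau for k in keys]) for q in queries]


def kl(p, q):
    """KL(p || q) for two discrete distributions over the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def relational_kd_loss(teacher_embs, student_embs, tau=1.0):
    # Teacher-teacher: each teacher embedding against all teacher embeddings.
    p_tt = similarity_distributions(teacher_embs, teacher_embs, tau)
    # Teacher-student: each student embedding against all teacher embeddings,
    # the same comparison the student makes when querying the knowledge base.
    p_ts = similarity_distributions(student_embs, teacher_embs, tau)
    return sum(kl(a, b) for a, b in zip(p_tt, p_ts)) / len(p_tt)


teacher = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
print(relational_kd_loss(teacher, teacher))  # 0.0: identical neighbor structure
print(relational_kd_loss(teacher, [[0.0, 1.0], [1.0, 0.0], [0.7, 0.7]]))  # positive: first two rows swapped
```

The loss is zero whenever each student embedding induces the same distribution over stored teacher embeddings as the corresponding teacher embedding does, so the student is penalized only when its sense of which stored inputs are neighbors differs from the teacher's.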
Our loss function also includes a standard cross-entropy term that penalizes discrepancies between the student's predictions and the teacher's predictions.
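The prediction-matching term and its combination with the similarity term might look like the following sketch. It assumes the teacher's predictions are used as soft labels (class probabilities) and that the two terms are summed with a hypothetical weight `beta`; the article does not say how the terms are weighted, and the function names are illustrative.

```python
import math


def soft_cross_entropy(teacher_probs, student_probs, eps=1e-12):
    """Cross-entropy of the student's class distribution against the teacher's soft labels."""
    # eps guards against log(0) when the student assigns zero probability to a class.
    return -sum(t * math.log(s + eps) for t, s in zip(teacher_probs, student_probs))


def total_loss(sim_kl_loss, teacher_batch, student_batch, beta=1.0):
    """Hypothetical weighted sum of the similarity-matching (KL) term and the
    prediction-matching (cross-entropy) term, averaged over the batch."""
    ce = sum(soft_cross_entropy(t, s) for t, s in zip(teacher_batch, student_batch)) / len(teacher_batch)
    return sim_kl_loss + beta * ce


teacher = [[0.8, 0.2]]
print(soft_cross_entropy(teacher[0], [0.8, 0.2]))  # lowest when the student matches the teacher
print(soft_cross_entropy(teacher[0], [0.5, 0.5]))  # higher: student is less confident than the teacher
print(total_loss(0.1, teacher, [[0.8, 0.2]], beta=1.0))
```

Because the teacher's distribution is fixed during student training, minimizing this cross-entropy is equivalent to minimizing KL(teacher || student): the two differ only by the teacher's entropy, which does not depend on the student.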
Experiments and results
In tests, we used ReAugKD to distill the 12-layer BERT-Base model into a six-layer BERT model and evaluated performance on six datasets from the GLUE benchmark. Our method achieves state-of-the-art results on five of the six datasets, with an average improvement of 0.42% over the previous best KD method and improvements of 1.37% and 1.43% on two of the benchmark tasks.
The version of ReAugKD that uses knowledge base retrieval also exhibits an improvement of 0.45% over ReAugKD without retrieval, verifying the benefit of retrieval augmentation in our approach.