RescoreBERT: Using BERT models to improve ASR rescoring
Knowledge distillation and discriminative training enable efficient use of a BERT-based model to rescore automatic-speech-recognition hypotheses.
When someone speaks to a voice agent like Alexa, an automatic speech recognition (ASR) model converts the speech to text. Typically, the core ASR model is trained on limited data, which means that it can struggle with rare words and phrases. So the ASR model’s hypotheses usually pass to a language model — a model that encodes the probabilities of sequences of words — trained on a much larger body of texts. The language model reranks the hypotheses, with the goal of improving ASR accuracy.
In natural-language processing, one of the most widely used language models is BERT (bidirectional encoder representations from Transformers). To use BERT as a rescoring model, one typically masks each input token in turn and computes its log-likelihood from the rest of the input, then sums those scores to produce a total score called PLL (pseudo log-likelihood). However, this requires a separate BERT forward pass for every token of every hypothesis, which makes it too expensive for real-time ASR. For rescoring, most ASR systems instead use more efficient long short-term memory (LSTM) language models.
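To make the cost concrete, here is a minimal Python sketch of PLL scoring. The callable `masked_log_prob` is a hypothetical stand-in for a BERT forward pass on the masked sequence. The toy unigram model and its probabilities are invented for illustration, and unlike a real masked language model it ignores context. The call counter shows that scoring a single hypothesis takes one model call per token.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

MASK = "[MASK]"

# Hypothetical signature: (masked sequence, masked position, original token) -> log P(token)
MaskedLogProb = Callable[[List[str], int, str], float]


def pseudo_log_likelihood(tokens: Sequence[str], masked_log_prob: MaskedLogProb) -> float:
    """Sum log P(token_i | all other tokens), masking one position at a time."""
    total = 0.0
    for i, token in enumerate(tokens):
        masked = list(tokens)
        masked[i] = MASK  # hide the token being scored
        total += masked_log_prob(masked, i, token)  # one model call per token
    return total


def make_toy_model(unigram_probs: Dict[str, float]) -> Tuple[MaskedLogProb, Dict[str, int]]:
    """Toy stand-in for BERT that ignores context and counts how often it is called."""
    calls = {"n": 0}

    def masked_log_prob(masked: List[str], i: int, token: str) -> float:
        calls["n"] += 1
        return math.log(unigram_probs.get(token, 1e-6))

    return masked_log_prob, calls


# Invented probabilities, for illustration only
probs = {"is": 0.2, "fission": 0.01, "the": 0.3, "opposite": 0.05, "of": 0.3, "fusion": 0.02}
model, calls = make_toy_model(probs)
score = pseudo_log_likelihood("is fission the opposite of fusion".split(), model)
print(calls["n"])  # 6: one forward pass per token
```

For an n-best list of N hypotheses averaging L tokens each, this adds up to roughly N × L forward passes per utterance.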
At this year’s International Conference on Acoustics, Speech, and Signal Processing (ICASSP), we presented a paper proposing a new model, RescoreBERT, that brings BERT’s power to second-pass rescoring.
In our experiments, RescoreBERT reduced an ASR model’s error rate by up to 13% relative to a traditional LSTM-based rescoring model. At the same time, thanks to a combination of knowledge distillation and discriminative training, it remains efficient enough for commercial deployment. In fact, we recently partnered with the Alexa team working on the Alexa Teacher Model — a large, pretrained, multilingual model with billions of parameters that encodes language as well as salient patterns of interactions with Alexa — and deployed RescoreBERT to production to delight Alexa customers.
To get a sense of the value of rescoring, suppose that an ASR model outputs these hypotheses, ranked from most to least likely: (a) “is fishing the opposite of fusion”, (b) “is fission the opposite of fusion”, and (c) “is fission the opposite of fashion”. Without second-pass rescoring, ASR would give an incorrect output: “is fishing the opposite of fusion”. If the second-pass language model does its job well, it should promote “is fission the opposite of fusion” to the top of the list. A language model trained from scratch on a limited set of data, however, will often struggle with rare words such as “fission”.
To reduce the computational expense of computing PLL scores, we adapt previous work from Amazon and pass the BERT model’s output through a neural network trained to mimic the PLL scores assigned by a larger, “teacher” model. We name this method MLM (masked language model) distillation, because the distilled model is trained to match the teacher’s predictions of masked inputs.
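The distillation step can be viewed as a regression problem. The minimal Python sketch below assumes a squared-error objective and replaces the small BERT model with fixed, hypothetical sentence embeddings that feed a linear scoring head (`student_score`). The embeddings and teacher PLL scores are made up. In the actual method the student network itself is trained, but the idea is the same: the student learns to output, in one pass, the score the teacher would obtain by masking every token.

```python
from typing import List, Sequence, Tuple


def student_score(weights: Sequence[float], bias: float, embedding: Sequence[float]) -> float:
    """Linear scoring head over a (hypothetical) sentence embedding from the student."""
    return sum(w * x for w, x in zip(weights, embedding)) + bias


def distillation_loss(weights: Sequence[float], bias: float,
                      embeddings: List[List[float]], teacher_pll: List[float]) -> float:
    """Mean squared error between student scores and teacher PLL scores."""
    residuals = [student_score(weights, bias, e) - t for e, t in zip(embeddings, teacher_pll)]
    return sum(r * r for r in residuals) / len(residuals)


def train_head(embeddings: List[List[float]], teacher_pll: List[float],
               lr: float = 0.1, steps: int = 1000) -> Tuple[List[float], float]:
    """Fit the head to the teacher's scores with plain full-batch gradient descent."""
    dim, n = len(embeddings[0]), len(embeddings)
    weights, bias = [0.0] * dim, 0.0
    for _ in range(steps):
        grad_w, grad_b = [0.0] * dim, 0.0
        for e, t in zip(embeddings, teacher_pll):
            r = student_score(weights, bias, e) - t  # residual for this hypothesis
            for k in range(dim):
                grad_w[k] += 2.0 * r * e[k] / n
            grad_b += 2.0 * r / n
        weights = [w - lr * g for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b
    return weights, bias


# Hypothetical student embeddings and teacher PLL scores for three hypotheses
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
teacher_pll = [-12.0, -20.0, -15.0]

initial_loss = distillation_loss([0.0, 0.0], 0.0, embeddings, teacher_pll)
weights, bias = train_head(embeddings, teacher_pll)
final_loss = distillation_loss(weights, bias, embeddings, teacher_pll)
print(initial_loss, final_loss)
```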
The score output by the distilled model is linearly interpolated with the first-pass ASR score to produce a final score. Because the much smaller distilled BERT model scores a hypothesis in a single forward pass, rather than one pass per masked token as the large teacher requires, this approach reduces latency.
Because the first- and second-pass scores are linearly interpolated, it’s not enough for the rescoring model to give the correct hypothesis a better (in this case, lower) second-pass score than its competitors; the correct hypothesis’s interpolated score must also be the lowest among all hypotheses.
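The following minimal sketch illustrates why a lower second-pass score is not sufficient by itself. It assumes the interpolation takes the form s = s1 + β·s2, with lower scores being better; the scores for the three “fission” hypotheses and the weight β (`beta`) are invented for illustration.

```python
from typing import List, Tuple

# (hypothesis text, first-pass score, second-pass score); lower scores are better
Hypothesis = Tuple[str, float, float]


def interpolate(first_pass: float, second_pass: float, beta: float) -> float:
    """Assumed linear interpolation of first- and second-pass scores."""
    return first_pass + beta * second_pass


def rerank(hyps: List[Hypothesis], beta: float = 1.0) -> List[str]:
    """Return hypothesis texts ordered from best (lowest combined score) to worst."""
    ranked = sorted(hyps, key=lambda h: interpolate(h[1], h[2], beta))
    return [text for text, _, _ in ranked]


# Invented scores: the correct hypothesis has the lowest second-pass score (9.8) ...
hyps = [
    ("is fishing the opposite of fusion", 2.0, 10.0),
    ("is fission the opposite of fusion", 2.5, 9.8),
    ("is fission the opposite of fashion", 3.0, 12.0),
]
print(rerank(hyps)[0])  # is fishing the opposite of fusion  (12.0 beats 12.3)

# ... but it wins only once its second-pass margin offsets its worse first-pass score
better = [hyps[0], ("is fission the opposite of fusion", 2.5, 9.0), hyps[2]]
print(rerank(better)[0])  # is fission the opposite of fusion  (11.5 beats 12.0)
```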
As a result, it would be beneficial to account for first-pass scores when training the second-pass rescoring model. However, MLM distillation aims only to reproduce the PLL scores and so does not take the first-pass scores into account. To incorporate them, we apply discriminative training after MLM distillation.
Specifically, we train RescoreBERT so that, when the linearly interpolated first- and second-pass scores are used to rerank the hypotheses, ASR errors are minimized. Previous research has captured this objective with the MWER (minimum word error rate) loss, which minimizes the expected number of word errors, where the expectation is taken over a distribution of hypotheses derived from their scores.
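A minimal Python sketch of an MWER-style loss over an n-best list follows. It assumes the common formulation in which interpolated scores are turned into hypothesis probabilities by a softmax over negated scores, and the mean word-error count is subtracted as a baseline; because the probabilities sum to one, the baseline only shifts the loss by a constant. The function names and toy numbers are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def mwer_loss(first_pass: Sequence[float], second_pass: Sequence[float],
              word_errors: Sequence[float], beta: float = 1.0) -> float:
    """Expected word errors over the n-best list, relative to the mean error count."""
    scores = [a + beta * b for a, b in zip(first_pass, second_pass)]
    probs = softmax([-s for s in scores])  # lower score -> higher probability
    mean_err = sum(word_errors) / len(word_errors)
    return sum(p * (e - mean_err) for p, e in zip(probs, word_errors))


# Toy n-best list; word errors are counted against "is fission the opposite of fusion"
first_pass = [2.0, 2.5, 3.0]
word_errors = [1, 0, 1]
worse = mwer_loss(first_pass, [10.0, 9.8, 12.0], word_errors)
better = mwer_loss(first_pass, [10.0, 9.0, 12.0], word_errors)
print(worse, better)  # the rescorer that favors the correct hypothesis has the lower loss
```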
We introduce a new loss function, MWED (matching word error distribution), which matches the distribution of hypothesis scores to the distribution of word errors across individual hypotheses. We show that MWED is a strong alternative to the standard MWER loss: it improves performance on English, although it slightly degrades performance on Japanese.
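One way to express this matching is as a cross-entropy between two softmax distributions over the n-best list: one over word-error counts and one over interpolated scores, so that hypotheses with more errors are pushed toward higher (worse) scores. The sketch below assumes that form and uses a hypothetical fixed `temperature` to put the scores on the scale of the error counts; the normalization used in the paper may differ.

```python
import math
from typing import List, Sequence


def log_softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


def mwed_loss(first_pass: Sequence[float], second_pass: Sequence[float],
              word_errors: Sequence[float], beta: float = 1.0,
              temperature: float = 1.0) -> float:
    """Cross-entropy from the word-error distribution to the score distribution."""
    scores = [a + beta * b for a, b in zip(first_pass, second_pass)]
    target = [math.exp(v) for v in log_softmax(word_errors)]  # more errors -> more mass
    log_pred = log_softmax([s / temperature for s in scores])  # higher score -> more mass
    return -sum(t * lp for t, lp in zip(target, log_pred))


word_errors = [1, 0, 1]
zeros = [0.0, 0.0, 0.0]
aligned = mwed_loss(zeros, [11.0, 10.0, 11.0], word_errors)     # correct hypothesis scored lowest
misaligned = mwed_loss(zeros, [10.0, 11.0, 10.0], word_errors)  # correct hypothesis scored highest
print(aligned < misaligned)  # True
```

The loss reaches its minimum, the entropy of the word-error distribution, when the two distributions coincide, as in the `aligned` case.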
Finally, to demonstrate the advantage of discriminative training, we show that while BERT trained with MLM distillation alone reduces WER by 3% to 6% relative to the LSTM baseline, RescoreBERT, trained with a discriminative objective, reduces it by 7% to 13% on the same test sets.