Transformer model distillation
Overview
Transformer models that are pre-trained on large corpora, such as BERT, XLNet and XLM, have been shown to improve the accuracy of many NLP tasks. However, such models have two distinct disadvantages: (1) model size and (2) speed, since these large models are computationally heavy.
One possible approach to overcome these drawbacks is to use Knowledge Distillation (KD). In this approach a large model is trained on the data set and then used to teach a much smaller and more efficient network. This is often referred to as Student-Teacher training, where a teacher network adds its error to the student's loss function, helping the student network converge to a better solution.
Knowledge Distillation
One approach is similar to the method in Hinton 2015 [1]. The loss function is modified to include a measure of the divergence between the student and teacher distributions, computed using either KL divergence or MSE between the logits of the student and the teacher networks:
\(loss = w_s \cdot loss_{student} + w_d \cdot KL(logits_{student} / T || logits_{teacher} / T)\)
where \(T\) is a temperature value used to soften the logits prior to applying softmax, \(loss_{student}\) is the original loss of the student network obtained during regular training, and \(w_s\) and \(w_d\) weight the student loss and the distillation loss respectively.
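As a concrete illustration, the weighted loss above can be computed with standard PyTorch operations. The sketch below is not the library's internal implementation; it simply mirrors the formula, softening both sets of logits with the temperature before taking the KL divergence.

import torch.nn.functional as F

def kd_loss(student_loss, student_logits, teacher_logits,
            loss_w=1.0, dist_w=0.1, temperature=1.0):
    # Soften both distributions with the temperature before comparing them
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # KL divergence expects log-probabilities as input and probabilities as target
    distill = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    # Weighted sum of the original student loss and the distillation term
    return loss_w * student_loss + dist_w * distill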
TeacherStudentDistill
This class adds distillation support to a model. To use it, the student model must handle training with the TeacherStudentDistill class; see nlp_architect.procedures.token_tagging.do_kd_training for an example of how to train a neural tagger from a transformer model using distillation.
class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, dist_w: float = 0.1, loss_w: float = 1.0, loss_function='kl')
Teacher-Student knowledge distillation helper. Use this object when training a model with KD and a teacher model.
Parameters: - teacher_model (TrainableModel) – teacher model
- temperature (float, optional) – KD temperature. Defaults to 1.0.
- dist_w (float, optional) – distillation loss weight. Defaults to 0.1.
- loss_w (float, optional) – student loss weight. Defaults to 1.0.
- loss_function (str, optional) – loss function to use: kl for KLDivLoss or mse for MSELoss. Defaults to kl.
static add_args(parser: argparse.ArgumentParser)
Add KD arguments to the parser.
Parameters: parser (argparse.ArgumentParser) – argument parser
distill_loss(loss, student_logits, teacher_logits)
Add the KD loss to the student loss.
Parameters: - loss – student loss
- student_logits – student model logits
- teacher_logits – teacher model logits
Returns: KD loss
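The following sketch shows how the helper might be wired into a custom training step. The teacher model, student model, data loader and optimizer are placeholders, and only the TeacherStudentDistill calls follow the signatures documented above.

import torch
from nlp_architect.nn.torch.distillation import TeacherStudentDistill

# teacher_model: a trained TrainableModel; student_model, train_loader, optimizer: placeholders
distill = TeacherStudentDistill(teacher_model, temperature=2.0, dist_w=0.1,
                                loss_w=1.0, loss_function="kl")

for batch in train_loader:
    student_loss, student_logits = student_model(batch)   # assumed student forward pass
    with torch.no_grad():
        teacher_logits = teacher_model(batch)              # assumed teacher forward pass
    # Combine the student loss with the distillation term using the documented helper
    loss = distill.distill_loss(student_loss, student_logits, teacher_logits)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()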
Supported models
NeuralTagger
Useful for training taggers from Transformer models. NeuralTagger models that use LSTM- or CNN-based embedders are ~3M parameters in size (~30-100x smaller than BERT models) and ~10x faster on average.
Usage:
- Train a transformer tagger using TransformerTokenClassifier or the nlp-train transformer_token command
- Train a neural tagger (NeuralTagger) using the trained transformer model as a teacher, with a TeacherStudentDistill object configured with that transformer model. This can be done using NeuralTagger's train loop or the nlp-train tagger_kd command; see the sketch after this list
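The outline below summarizes the two steps in code form. load_trained_transformer_tagger, build_neural_tagger and train_tagger_with_distillation are hypothetical helpers standing in for the actual TransformerTokenClassifier and NeuralTagger training calls, whose exact arguments are not reproduced here.

from nlp_architect.nn.torch.distillation import TeacherStudentDistill

# Step 1: the transformer token classifier trained beforehand acts as the teacher
teacher = load_trained_transformer_tagger()        # hypothetical helper

# Step 2: configure the distillation helper with the teacher and pass it to the
# neural tagger's training procedure (or use the nlp-train tagger_kd command)
distill = TeacherStudentDistill(teacher, temperature=2.0, dist_w=0.1,
                                loss_w=1.0, loss_function="kl")
tagger = build_neural_tagger()                     # hypothetical helper (LSTM/CNN embedder)
train_tagger_with_distillation(tagger, distill)    # hypothetical wrapper around NeuralTagger's train loop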
Note
More models supporting distillation will be added in future releases.
Pseudo Labeling
This method can be used to produce pseudo-labels when training the student on unlabeled examples. The pseudo-label is obtained by applying argmax to the logits of the teacher model, resulting in the following loss:
\(loss = CE(\hat{y}, \hat{y}_t)\)
where CE is the Cross Entropy loss, \(\hat{y}\) is the entity label class predicted by the student model and \(\hat{y}_t\) is the label predicted by the teacher model.
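A minimal sketch of this pseudo-labeling loss, assuming per-token logits from both models, could look as follows (illustrative only, not the library's implementation).

import torch.nn.functional as F

def pseudo_label_loss(student_logits, teacher_logits):
    # The teacher's argmax predictions serve as pseudo-labels for unlabeled examples
    pseudo_labels = teacher_logits.argmax(dim=-1)
    # Cross entropy of the student's predictions against those pseudo-labels
    return F.cross_entropy(student_logits.view(-1, student_logits.size(-1)),
                           pseudo_labels.view(-1))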
[1] Geoffrey Hinton, Oriol Vinyals, Jeff Dean. Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531