Source code for espnet.nets.chainer_backend.transformer.label_smoothing_loss

# encoding: utf-8
"""Class Declaration of Transformer's Label Smootion loss."""

import logging

import chainer
import chainer.functions as F


class LabelSmoothingLoss(chainer.Chain):
    """Label Smoothing Loss.

    Args:
        smoothing (float): smoothing rate (0.0 means the conventional CE).
        n_target_vocab (int): number of classes.
        normalize_length (bool): normalize loss by sequence length if True.

    """

    def __init__(
        self, smoothing, n_target_vocab, normalize_length=False, ignore_id=-1
    ):
        """Initialize Loss."""
        super(LabelSmoothingLoss, self).__init__()
        self.use_label_smoothing = False
        if smoothing > 0.0:
            logging.info("Use label smoothing")
            self.smoothing = smoothing
            self.confidence = 1.0 - smoothing
            self.use_label_smoothing = True
            self.n_target_vocab = n_target_vocab
        self.normalize_length = normalize_length
        self.ignore_id = ignore_id
        self.acc = None
    def forward(self, ys_block, ys_pad):
        """Forward Loss.

        Args:
            ys_block (chainer.Variable): Predicted labels.
            ys_pad (chainer.Variable): Target (true) labels.

        Returns:
            float: Training loss.

        """
        # Output (all together at once for efficiency)
        batch, length, dims = ys_block.shape
        concat_logit_block = ys_block.reshape(-1, dims)

        # Target reshape
        concat_t_block = ys_pad.reshape((batch * length))
        ignore_mask = concat_t_block >= 0
        n_token = ignore_mask.sum()
        normalizer = n_token if self.normalize_length else batch
        if not self.use_label_smoothing:
            loss = F.softmax_cross_entropy(concat_logit_block, concat_t_block)
            loss = loss * n_token / normalizer
        else:
            log_prob = F.log_softmax(concat_logit_block)
            broad_ignore_mask = self.xp.broadcast_to(
                ignore_mask[:, None], concat_logit_block.shape
            )
            pre_loss = (
                ignore_mask
                * log_prob[self.xp.arange(batch * length), concat_t_block]
            )
            loss = -F.sum(pre_loss) / normalizer
            label_smoothing = (
                broad_ignore_mask * -1.0 / self.n_target_vocab * log_prob
            )
            label_smoothing = F.sum(label_smoothing) / normalizer
            loss = self.confidence * loss + self.smoothing * label_smoothing
        return loss
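
The sketch below shows one way this loss might be exercised on CPU with NumPy inputs; it is not part of the ESPnet source, and the batch size, sequence length, vocabulary size, and smoothing rate are illustrative assumptions. Padding positions in the target tensor are marked with -1, matching the default ignore_id above.

# Minimal usage sketch (assumed shapes and hyperparameters).
import numpy as np

import chainer

from espnet.nets.chainer_backend.transformer.label_smoothing_loss import (
    LabelSmoothingLoss,
)

vocab_size = 10
criterion = LabelSmoothingLoss(
    smoothing=0.1, n_target_vocab=vocab_size, normalize_length=True
)

# Fake decoder outputs: (batch, length, vocab) logits and padded targets,
# where -1 marks padding positions that are masked out of the loss.
ys_block = chainer.Variable(
    np.random.randn(2, 5, vocab_size).astype(np.float32)
)
ys_pad = np.array([[3, 1, 4, 1, -1], [5, 9, 2, -1, -1]], dtype=np.int32)

loss = criterion(ys_block, ys_pad)
print(loss.array)  # scalar training loss

With normalize_length=True the loss is averaged over the number of non-padding tokens; with the default False it is averaged over the batch size instead.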