espnet.nets.pytorch_backend.lm.default.ClassifierWithState
class espnet.nets.pytorch_backend.lm.default.ClassifierWithState(predictor, lossfun=CrossEntropyLoss(), label_key=-1)
Bases: Module
A wrapper for a PyTorch RNN language model (RNNLM).
Initialize the class.
:param torch.nn.Module predictor : The RNNLM
:param function lossfun : The loss function to use
:param int/str label_key : Position in args (int) or key in kwargs (str) of the ground-truth labels passed to forward
buff_predict(state, x, n)
Predict new tokens from buffered inputs.
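The one-line description leaves the buffering semantics open. A minimal sketch, assuming the buffer holds n independent hypotheses, each with its own state, and that each one is scored separately by a predict-style function with the results collected in order. The names toy_predict and buff_predict_sketch are hypothetical, and plain Python lists stand in for tensors.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

State = Optional[int]
# A predict-style function: (state, token) -> (new state, log-prob vector).
PredictFn = Callable[[State, int], Tuple[State, List[float]]]


def toy_predict(state: State, token: int) -> Tuple[State, List[float]]:
    """Hypothetical predictor: the state counts steps; output is uniform over 3 tokens."""
    new_state = 1 if state is None else state + 1
    return new_state, [math.log(1.0 / 3.0)] * 3


def buff_predict_sketch(
    predict: PredictFn,
    states: Optional[Sequence[State]],
    xs: Sequence[int],
    n: int,
) -> Tuple[List[State], List[List[float]]]:
    """Score n buffered hypotheses one at a time (assumed semantics)."""
    new_states: List[State] = []
    new_log_probs: List[List[float]] = []
    for i in range(n):
        # No buffered states means every hypothesis starts from scratch.
        state_i = None if states is None else states[i]
        state_i, log_y = predict(state_i, xs[i])
        new_states.append(state_i)
        new_log_probs.append(log_y)
    return new_states, new_log_probs
```

Called with states=None, for example buff_predict_sketch(toy_predict, None, [0, 2], 2), each hypothesis starts from a fresh state, so both returned states are 1.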
final(state, index=None)
Predict final log probabilities for given state using the predictor.
:param state : The state
:return The final log probabilities
:rtype torch.Tensor
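The final step can be sketched as delegating to the predictor when it provides its own final method. This is a sketch under assumptions: the index argument is assumed to select one entry from a list of per-hypothesis states, and the 0.0 fallback for predictors without a final method is assumed, not documented above. ToyPredictor and final_sketch are hypothetical names.

```python
from typing import Any, Optional


class ToyPredictor:
    """Hypothetical predictor whose final log probability is minus the state value."""

    def final(self, state: float) -> float:
        return -state


def final_sketch(predictor: Any, state: Any, index: Optional[int] = None) -> float:
    # Predictors without a final() hook contribute nothing (assumed fallback).
    if not hasattr(predictor, "final"):
        return 0.0
    # With an index, score only that hypothesis's state (assumed semantics).
    if index is not None:
        return predictor.final(state[index])
    return predictor.final(state)
```

Returning a neutral 0.0 keeps a caller that sums scores working unchanged whether or not the wrapped model defines an end-of-sequence score.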
forward(state, *args, **kwargs)
Compute the loss value for an input and label pair.
Notes
It also computes accuracy and stores it in an attribute. When label_key is an int, the corresponding element in args is treated as the ground-truth labels; when it is a str, the corresponding element in kwargs is used. All elements of args and kwargs other than the ground-truth labels are features. The features are fed to the predictor, and its output is compared with the ground-truth labels.
:param torch.Tensor state : The LM state
:param list[torch.Tensor] args : Input minibatch
:param dict[torch.Tensor] kwargs : Input minibatch
:return Loss value
:rtype torch.Tensor
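The label-selection rule in the notes can be sketched as follows. It only shows how the ground-truth labels are separated from the features; the bounds check, the error messages, and the name split_labels are assumptions for illustration, and the loss and accuracy computation is left out.

```python
from typing import Any, Dict, Tuple, Union


def split_labels(
    label_key: Union[int, str], args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Any, Tuple[Any, ...], Dict[str, Any]]:
    """Return (labels, feature args, feature kwargs) according to label_key."""
    if isinstance(label_key, int):
        if not -len(args) <= label_key < len(args):
            raise ValueError(f"Label key {label_key} is out of bounds")
        labels = args[label_key]
        # Normalize a negative index so it can be used to slice around the label.
        pos = label_key % len(args)
        return labels, args[:pos] + args[pos + 1 :], dict(kwargs)
    if isinstance(label_key, str):
        if label_key not in kwargs:
            raise ValueError(f'Label key "{label_key}" is not found')
        features_kw = {k: v for k, v in kwargs.items() if k != label_key}
        return kwargs[label_key], args, features_kw
    raise TypeError(f"label_key must be int or str, got {type(label_key).__name__}")
```

With the default label_key=-1, the last positional argument holds the labels, so in a call such as forward(state, x, t) the labels are t and the only feature is x.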
predict(state, x)
Predict log probabilities for given state and input x using the predictor.
:param torch.Tensor state : The current state
:param torch.Tensor x : The input
:return A tuple (new state, log prob vector)
:rtype (torch.Tensor, torch.Tensor)
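Turning a predictor's raw output into a log probability vector amounts to a log-softmax over the vocabulary. A minimal sketch, assuming the predictor returns unnormalized scores (logits) together with its new state; predict_sketch and toy_predictor are hypothetical, and plain Python lists stand in for tensors. In PyTorch the normalization step would typically be torch.nn.functional.log_softmax(z, dim=-1).

```python
import math
from typing import Callable, List, Tuple

Predictor = Callable[[int, int], Tuple[int, List[float]]]


def log_softmax(scores: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def toy_predictor(state: int, x: int) -> Tuple[int, List[float]]:
    """Hypothetical predictor: favours the token after x in a 3-token vocabulary."""
    logits = [0.0, 0.0, 0.0]
    logits[(x + 1) % 3] = 2.0
    return state + 1, logits


def predict_sketch(predictor: Predictor, state: int, x: int) -> Tuple[int, List[float]]:
    # Run the predictor, then normalize its scores into log probabilities.
    new_state, logits = predictor(state, x)
    return new_state, log_softmax(logits)
```

Exponentiating the returned vector gives a distribution that sums to one, so per-step values can be added across decoding steps as log probabilities.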