espnet.nets.pytorch_backend.lm.default.DefaultRNNLM
class espnet.nets.pytorch_backend.lm.default.DefaultRNNLM(n_vocab, args)
Bases: BatchScorerInterface, LMInterface, Module
Default RNNLM for LMInterface Implementation.
NOTE
PyTorch appears to leak memory when a single GPU computes this after data-parallel execution. When multiple GPUs compute it in parallel, it seems to be fine. See also https://github.com/espnet/espnet/issues/1075
Initialize class.
- Parameters:
- n_vocab (int) – The size of the vocabulary
- args (argparse.Namespace) – Configuration options. See add_arguments.
static add_arguments(parser)
Add arguments to command line argument parser.
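The pairing of add_arguments(parser) with a constructor that takes args follows a common pattern: the static method registers model options on a parser, and the parsed namespace is passed to the constructor. The sketch below illustrates that pattern only. ToyLM and the option names --layer and --unit are hypothetical stand-ins chosen for illustration, not the confirmed options of DefaultRNNLM.

```python
import argparse


class ToyLM:
    """Hypothetical stand-in mimicking add_arguments / __init__(n_vocab, args)."""

    @staticmethod
    def add_arguments(parser):
        # Option names are illustrative assumptions, not the real option set.
        parser.add_argument("--layer", type=int, default=2, help="Number of layers")
        parser.add_argument("--unit", type=int, default=650, help="Hidden units")
        return parser

    def __init__(self, n_vocab, args):
        # Read the model configuration from the parsed namespace.
        self.n_vocab = n_vocab
        self.layer = args.layer
        self.unit = args.unit


parser = argparse.ArgumentParser()
ToyLM.add_arguments(parser)
# Parse an explicit list so the sketch does not depend on sys.argv.
args = parser.parse_args(["--layer", "1"])
lm = ToyLM(n_vocab=100, args=args)
```

Keeping option registration next to the model class means a training script can stay model-agnostic: it only needs to call add_arguments on whichever class was selected.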
batch_score(ys: Tensor, states: List[Any], xs: Tensor) → Tuple[Tensor, List[Any]]
Score new token batch.
- Parameters:
- ys (torch.Tensor) – torch.int64 prefix tokens (n_batch, ylen).
- states (List[Any]) – Scorer states for prefix tokens.
- xs (torch.Tensor) – The encoder feature that generates ys (n_batch, xlen, n_feat).
- Returns: Tuple of the batched scores for the next token, with shape (n_batch, n_vocab), and the list of next states for ys.
- Return type: tuple[torch.Tensor, List[Any]]
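The input and output contract of batch_score can be read as "one call to a single-hypothesis scorer per batch entry, with the results collected". The sketch below is a reference for the shapes only, under that assumption. It uses plain Python lists in place of tensors so that it runs without PyTorch, and toy_score and batch_score_via_loop are hypothetical helpers. An actual implementation would presumably batch the computation rather than loop.

```python
import math


def toy_score(y, state, x):
    """Hypothetical single-hypothesis scorer: uniform log-probabilities."""
    n_vocab = 3
    steps = 0 if state is None else state
    # A real LM would depend on y and state; here only the step count changes.
    return [-math.log(n_vocab)] * n_vocab, steps + 1


def batch_score_via_loop(score_fn, ys, states, xs):
    """Score each hypothesis separately and collect scores and next states."""
    scores, next_states = [], []
    for y, state, x in zip(ys, states, xs):
        s, ns = score_fn(y, state, x)
        scores.append(s)
        next_states.append(ns)
    return scores, next_states


ys = [[0, 1], [0, 2]]   # (n_batch, ylen) prefix tokens
states = [None, 4]      # one state per hypothesis; None marks a fresh state
xs = [None, None]       # encoder features are unused by this toy scorer
scores, next_states = batch_score_via_loop(toy_score, ys, states, xs)
```

Here scores has one row of n_vocab entries per hypothesis, and next_states keeps the same ordering as ys.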
final_score(state)
Score eos.
- Parameters: state – Scorer state for prefix tokens
- Returns: final score
- Return type: float
forward(x, t)
Compute LM loss value from buffer sequences.
- Parameters:
- x (torch.Tensor) – Input ids. (batch, len)
- t (torch.Tensor) – Target ids. (batch, len)
- Returns: Tuple of the loss to backpropagate (scalar), the negative log-likelihood of t, -log p(t) (scalar), and the number of elements in x (scalar)
- Return type: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Notes
The last two return values are used to compute perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
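Perplexity is the exponential of the average negative log-likelihood per token. The sketch below accumulates the last two return values over several batches and converts them to perplexity. It assumes the scalar tensors have already been converted to Python numbers (for example with .item()); the batches list holds made-up values for illustration.

```python
import math


def perplexity(total_nll, total_count):
    """Return exp(total_nll / total_count), the per-token perplexity."""
    if total_count <= 0:
        raise ValueError("total_count must be positive")
    return math.exp(total_nll / total_count)


# Hypothetical (nll, count) pairs, as if collected from successive forward calls.
batches = [(12.0, 5), (8.0, 5)]
total_nll = sum(nll for nll, _ in batches)
total_count = sum(count for _, count in batches)
ppl = perplexity(total_nll, total_count)  # exp(20 / 10)
```

Summing the negative log-likelihoods and counts before dividing weights every token equally, which differs from averaging per-batch perplexities when batch sizes vary.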
load_state_dict(d)
Load state dict.
score(y, state, x)
Score new token.
- Parameters:
- y (torch.Tensor) – 1D torch.int64 prefix tokens.
- state – Scorer state for prefix tokens
- x (torch.Tensor) – 2D encoder feature that generates ys.
- Returns: Tuple of the torch.float32 scores for the next token (n_vocab) and the next state for ys
- Return type: tuple[torch.Tensor, Any]
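A decoder typically calls score once per step, threading the returned state into the next call, and applies final_score when a hypothesis ends. The sketch below shows that call pattern with greedy decoding. ToyScorer and greedy_decode are hypothetical, plain lists stand in for tensors so the example runs without PyTorch, and the way ESPnet's own search combines these scores is not shown here.

```python
import math


class ToyScorer:
    """Hypothetical scorer with the same call shape as score() and final_score()."""

    def __init__(self, n_vocab, eos):
        self.n_vocab = n_vocab
        self.eos = eos

    def score(self, y, state, x):
        # The state is simply the number of tokens scored so far.
        steps = 0 if state is None else state
        # Give probability 0.7 to (last token + 1) and share 0.3 among the rest.
        favored = (y[-1] + 1) % self.n_vocab
        rest = 0.3 / (self.n_vocab - 1)
        logp = [math.log(0.7 if v == favored else rest) for v in range(self.n_vocab)]
        return logp, steps + 1

    def final_score(self, state):
        # Fixed end-of-sequence score, chosen arbitrarily for illustration.
        return -0.5


def greedy_decode(scorer, sos, max_len):
    """Pick the best token at each step until eos or max_len is reached."""
    y, state, total = [sos], None, 0.0
    for _ in range(max_len):
        logp, state = scorer.score(y, state, x=None)
        best = max(range(len(logp)), key=logp.__getitem__)
        total += logp[best]
        y.append(best)
        if best == scorer.eos:
            break
    return y, total + scorer.final_score(state), state


scorer = ToyScorer(n_vocab=4, eos=3)
tokens, total_score, final_state = greedy_decode(scorer, sos=0, max_len=10)
```

Passing the state back in on every call is what lets a recurrent model avoid recomputing the whole prefix at each step.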
state_dict()
Dump state dict.