espnet2.asr.decoder.hugging_face_transformers_decoder.HuggingFaceTransformersDecoder
class espnet2.asr.decoder.hugging_face_transformers_decoder.HuggingFaceTransformersDecoder(vocab_size: int, encoder_output_size: int, model_name_or_path: str, causal_lm: bool = False, prefix: str = '', postfix: str = '')
Bases: AbsDecoder, BatchScorerInterface
Hugging Face Transformers Decoder.
- Parameters:
  - encoder_output_size – output dimension of the encoder (the encoder attention dimension)
  - model_name_or_path – Hugging Face Transformers model name or local path
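The parameter list does not say how encoder_output_size is used. One plausible role, assumed here, is projecting encoder frames onto the hidden size of the pretrained Hugging Face model when the two dimensions differ. The sketch below uses plain Python lists in place of torch.nn.Linear; project_encoder_output and needs_projection are hypothetical helpers, not part of this class.

```python
from typing import List

Matrix = List[List[float]]


def project_encoder_output(
    enc_out: Matrix, weight: Matrix, bias: List[float]
) -> Matrix:
    """Apply y = W x + b to every encoder frame.

    enc_out: (n_frames, encoder_output_size)
    weight:  (hidden_size, encoder_output_size), laid out like torch.nn.Linear
    bias:    (hidden_size,)
    """
    return [
        [sum(w * x for w, x in zip(row, frame)) + b for row, b in zip(weight, bias)]
        for frame in enc_out
    ]


def needs_projection(encoder_output_size: int, hidden_size: int) -> bool:
    # Hypothetical rule: only project when the two dimensions differ.
    return encoder_output_size != hidden_size


# Two frames of a 3-dim encoder output mapped to a 2-dim decoder hidden size.
enc_out = [[1.0, 0.0, 2.0], [0.5, 1.0, 0.0]]
weight = [[1.0, 1.0, 1.0], [0.0, 2.0, 0.0]]
bias = [0.0, 1.0]
projected = project_encoder_output(enc_out, weight, bias)
print(projected)  # [[3.0, 1.0], [1.5, 3.0]]
```

When the sizes already match, an identity mapping would do, which is what needs_projection encodes.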
add_prefix_postfix(enc_out, hlens, ys_in_pad, ys_in_lens)
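add_prefix_postfix has no description. Judging from its name, its arguments, and the prefix/postfix constructor parameters, it presumably wraps each utterance's encoder output with the embedded prefix and postfix strings and then appends the target tokens, so that a decoder-only model (causal_lm=True) sees the speech as a prompt. A minimal sketch of that layout under this assumption, using lists instead of tensors and omitting re-padding into a batch; add_prefix_postfix_sketch is a hypothetical name.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")  # stands in for one embedding vector


def add_prefix_postfix_sketch(
    enc_out: Sequence[Sequence[T]],
    hlens: Sequence[int],
    ys_emb: Sequence[Sequence[T]],
    ys_in_lens: Sequence[int],
    prefix: Sequence[T],
    postfix: Sequence[T],
) -> Tuple[List[List[T]], List[int]]:
    """Build [prefix | enc_out[:hlen] | postfix | ys[:ylen]] per utterance.

    Returns the concatenated (unpadded) sequences and, for each one, the
    number of leading context positions that carry no target token.
    """
    sequences: List[List[T]] = []
    context_lens: List[int] = []
    for frames, hlen, tokens, ylen in zip(enc_out, hlens, ys_emb, ys_in_lens):
        # Drop padded encoder frames, then wrap them with the prompt embeddings.
        context = list(prefix) + list(frames[:hlen]) + list(postfix)
        # Append only the real (unpadded) target-token embeddings.
        sequences.append(context + list(tokens[:ylen]))
        context_lens.append(len(context))
    return sequences, context_lens


# Strings stand in for embedding vectors so the layout is easy to read.
seqs, ctx = add_prefix_postfix_sketch(
    enc_out=[["h1", "h2", "pad"], ["g1", "g2", "g3"]],
    hlens=[2, 3],
    ys_emb=[["y1", "y2"], ["z1", "pad"]],
    ys_in_lens=[2, 1],
    prefix=["<p>"],
    postfix=["</p>"],
)
print(seqs[0])  # ['<p>', 'h1', 'h2', '</p>', 'y1', 'y2']
print(ctx)      # [4, 5]
```

The returned context lengths mark where the target tokens start in each sequence, which is what a loss over only the target positions would need.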
batch_score(ys: Tensor, states: List[Any], xs: Tensor, speech: Tensor | None = None) → Tuple[Tensor, List[Any]]
Score a new token batch (required).
- Parameters:
  - ys (torch.Tensor) – torch.int64 prefix tokens (n_batch, ylen).
  - states (List[Any]) – Scorer states for the prefix tokens.
  - xs (torch.Tensor) – The encoder features that generate ys (n_batch, xlen, n_feat).
- Returns: Tuple of the batchified scores for the next token, with shape (n_batch, n_vocab), and the list of next states for ys.
- Return type: tuple[torch.Tensor, List[Any]]
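In beam search, batch_score returns one row of next-token scores per hypothesis. A common way to obtain them, assumed here rather than taken from this class, is to take the decoder logits at the last position of each prefix and apply a log-softmax over the vocabulary. The sketch below shows only that normalization step in plain Python; batch_score_sketch is a hypothetical name and the decoder call itself is omitted.

```python
import math
from typing import Any, List, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def batch_score_sketch(
    last_logits: Sequence[Sequence[float]], states: List[Any]
) -> Tuple[List[List[float]], List[Any]]:
    """Turn (n_batch, n_vocab) next-token logits into log-probabilities.

    last_logits[i] stands for the decoder output at the final position of
    hypothesis i. Nothing is cached, so states are returned unchanged.
    """
    scores = [log_softmax(row) for row in last_logits]
    return scores, states


scores, next_states = batch_score_sketch(
    [[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]], [None, None]
)
best = [max(range(len(row)), key=row.__getitem__) for row in scores]
print(best)  # [0, 0]
```

Log-probabilities are used instead of raw logits so that scores from several hypotheses (and from other scorers such as a CTC prefix scorer) can be added and compared on the same scale.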
forward(hs_pad: Tensor, hlens: Tensor, ys_in_pad: Tensor, ys_in_lens: Tensor) → Tuple[Tensor, Tensor]
Forward decoder.
- Parameters:
  - hs_pad – encoded memory, float32 (batch, maxlen_in, feat)
  - hlens – lengths of the encoded memory (batch)
  - ys_in_pad – input token IDs, int64 (batch, maxlen_out)
  - ys_in_lens – lengths of the input token sequences (batch)
- Returns: tuple containing:
  - x: decoded token scores before softmax (batch, maxlen_out, token)
  - olens: output lengths (batch,)
- Return type: tuple
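forward receives padded batches together with their lengths (hlens, ys_in_lens). Hugging Face decoders do not take lengths directly; padding is usually expressed through masks such as attention_mask and, for cross-attention, encoder_attention_mask, where 1 marks a real position. How this class builds its masks is not documented here, so the sketch below only shows one assumed conversion from lengths to such masks; lengths_to_attention_mask is a hypothetical helper. ESPnet's make_pad_mask uses the opposite polarity (True at padded positions), so its negation gives the same layout.

```python
from typing import List, Sequence


def lengths_to_attention_mask(
    lengths: Sequence[int], max_len: int = 0
) -> List[List[int]]:
    """Return a (batch, width) mask: 1 for real positions, 0 for padding.

    width is max(max_len, longest length), matching the Hugging Face
    attention_mask convention where 1 means "attend to this position".
    """
    width = max(max_len, max(lengths, default=0))
    return [[1 if t < n else 0 for t in range(width)] for n in lengths]


ys_in_lens = [3, 1]
hlens = [2, 4]
attention_mask = lengths_to_attention_mask(ys_in_lens)     # for ys_in_pad
encoder_attention_mask = lengths_to_attention_mask(hlens)  # for hs_pad
print(attention_mask)  # [[1, 1, 1], [1, 0, 0]]
```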
reload_pretrained_parameters()
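reload_pretrained_parameters has no description. Its name suggests it restores the pretrained Hugging Face weights, for example after a model-wide initialization step has overwritten them. The sketch below illustrates that snapshot-and-restore pattern with plain dictionaries standing in for a PyTorch state_dict; PretrainedWrapperSketch and generic_initialize are hypothetical names, and the code is not taken from the ESPnet source.

```python
import copy
from typing import Dict, List

Params = Dict[str, List[float]]


class PretrainedWrapperSketch:
    """Hypothetical module that keeps a copy of its pretrained weights."""

    def __init__(self, pretrained: Params) -> None:
        self.params: Params = copy.deepcopy(pretrained)
        # Snapshot taken at construction, before any global re-initialization.
        self._snapshot: Params = copy.deepcopy(pretrained)

    def reload_pretrained_parameters(self) -> None:
        # Restore the saved pretrained values over whatever is there now.
        self.params = copy.deepcopy(self._snapshot)


def generic_initialize(module: PretrainedWrapperSketch) -> None:
    # Hypothetical model-wide init that overwrites every weight with zeros.
    for name, values in module.params.items():
        module.params[name] = [0.0] * len(values)


m = PretrainedWrapperSketch({"w": [0.3, -1.2]})
generic_initialize(m)
print(m.params)  # {'w': [0.0, 0.0]}
m.reload_pretrained_parameters()
print(m.params)  # {'w': [0.3, -1.2]}
```

Copying on restore keeps the snapshot intact, so the reload can be repeated safely.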
score(ys, state, x, speech=None)
Score a new token (required).
- Parameters:
  - ys (torch.Tensor) – 1D torch.int64 prefix tokens.
  - state – Scorer state for the prefix tokens.
  - x (torch.Tensor) – The encoder features that generate ys.
- Returns: Tuple of the scores for the next token, with shape (n_vocab,), and the next state for ys.
- Return type: tuple[torch.Tensor, Any]
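score is the single-hypothesis counterpart of batch_score: it takes a 1-D prefix and returns a (n_vocab,) score vector instead of (n_batch, n_vocab). Assuming the two agree, score can be expressed as batch_score on a batch of one, unpacked afterwards, as sketched below. score_via_batch and toy_batch_score are hypothetical, and the toy scorer only stands in for the real decoder.

```python
from typing import Any, Callable, List, Sequence, Tuple

# (ys_batch, states) -> (scores of shape (n_batch, n_vocab), next states)
BatchScoreFn = Callable[
    [List[List[int]], List[Any]], Tuple[List[List[float]], List[Any]]
]


def score_via_batch(
    ys: Sequence[int], state: Any, batch_score: BatchScoreFn
) -> Tuple[List[float], Any]:
    """Score one 1-D prefix by treating it as a batch of size 1."""
    scores, states = batch_score([list(ys)], [state])
    return scores[0], states[0]


def toy_batch_score(
    ys_batch: List[List[int]], states: List[Any]
) -> Tuple[List[List[float]], List[Any]]:
    # Hypothetical scorer over 3 tokens: favors (last token + 1) % 3.
    rows = []
    for ys in ys_batch:
        row = [-5.0, -5.0, -5.0]
        row[(ys[-1] + 1) % 3] = 0.0
        rows.append(row)
    return rows, states


scores, next_state = score_via_batch([0, 2], None, toy_batch_score)
print(scores)  # [0.0, -5.0, -5.0]
```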