espnet2.asr.decoder.whisper_decoder.OpenAIWhisperDecoder
class espnet2.asr.decoder.whisper_decoder.OpenAIWhisperDecoder(vocab_size: int, encoder_output_size: int, dropout_rate: float = 0.0, whisper_model: str = 'small', download_dir: str | None = None, load_origin_token_embedding=False)
Bases: AbsDecoder, BatchScorerInterface
Transformer-based Speech-to-Text Decoder from OpenAI’s Whisper Model:
URL: https://github.com/openai/whisper
batch_score(ys: Tensor, states: List[Any], xs: Tensor) → Tuple[Tensor, List[Any]]
Score new token batch.
- Parameters:
- ys (torch.Tensor) – torch.int64 prefix tokens (n_batch, ylen).
- states (List[Any]) – Scorer states for prefix tokens.
- xs (torch.Tensor) – The encoder feature that generates ys (n_batch, xlen, n_feat).
- Returns: Tuple of the batchified scores for the next token, with shape (n_batch, n_vocab), and the next state list for ys.
- Return type: tuple[torch.Tensor, List[Any]]
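The contract of batch_score is one row of next-token scores per prefix plus one new state per prefix. The sketch below illustrates that shape contract with plain Python lists instead of tensors. `toy_batch_score` and its made-up logits are hypothetical; a real decoder derives the logits from the prefix and the encoder output xs. Scores are shown as log-probabilities, which is the usual convention for beam-search scorers.

```python
import math
from typing import Any, List, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def toy_batch_score(
    ys: List[List[int]], states: List[Any], n_vocab: int
) -> Tuple[List[List[float]], List[Any]]:
    """Hypothetical stand-in for batch_score: one score row per prefix.

    The logits are invented for illustration: they favour token
    (last_token + 1) % n_vocab. A real decoder would compute them from
    the prefix and the encoder features.
    """
    scores = []
    for prefix in ys:
        favoured = (prefix[-1] + 1) % n_vocab
        logits = [2.0 if v == favoured else 0.0 for v in range(n_vocab)]
        scores.append(log_softmax(logits))
    # This toy keeps no per-prefix state, so every next state is None.
    next_states = [None for _ in states]
    return scores, next_states


# Two prefixes in, a (2, n_vocab) score table and two states out.
scores, new_states = toy_batch_score([[0, 3], [0, 1]], [None, None], n_vocab=5)
```

Returning a list of states (rather than one batched object) matches the documented return type `tuple[torch.Tensor, List[Any]]`.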
forward(hs_pad: Tensor, hlens: Tensor, ys_in_pad: Tensor, ys_in_lens: Tensor) → Tuple[Tensor, Tensor]
Forward decoder.
- Parameters:
- hs_pad – encoded memory, float32 (batch, maxlen_in, feat)
- hlens – lengths of the encoded memory, (batch)
- ys_in_pad – input token ids, int64 (batch, maxlen_out) if input_layer == "embed"; otherwise an input tensor (batch, maxlen_out, #mels)
- ys_in_lens – lengths of the input token sequences, (batch)
- Returns: tuple containing:
  - x: decoded token scores before softmax, (batch, maxlen_out, token) if use_output_layer is True
  - olens: (batch,)
- Return type: tuple
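forward takes a padded token batch ys_in_pad together with its true lengths ys_in_lens, so positions past ys_in_lens[b] in x are padding and would normally be excluded from any loss. The sketch below shows how such a padded batch, its lengths, and the matching non-padding mask relate, using nested lists in place of torch tensors. The helper names and the pad id are hypothetical; which pad value ESPnet actually uses for ys_in_pad is not stated here.

```python
from typing import List, Tuple


def pad_token_batch(
    seqs: List[List[int]], pad_id: int
) -> Tuple[List[List[int]], List[int]]:
    """Pad variable-length token sequences to shape (batch, maxlen_out).

    Returns the padded batch and the per-sequence lengths, mirroring the
    ys_in_pad / ys_in_lens pair expected by forward().
    """
    lens = [len(s) for s in seqs]
    maxlen = max(lens)
    padded = [s + [pad_id] * (maxlen - len(s)) for s in seqs]
    return padded, lens


def non_pad_mask(lens: List[int], maxlen: int) -> List[List[bool]]:
    """True at real token positions, False at padding positions."""
    return [[t < n for t in range(maxlen)] for n in lens]


# Hypothetical token ids; 0 is used as the pad id for illustration only.
ys_in_pad, ys_in_lens = pad_token_batch([[1, 7, 8], [1, 9]], pad_id=0)
mask = non_pad_mask(ys_in_lens, maxlen=len(ys_in_pad[0]))
```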
forward_one_step(tgt: Tensor, tgt_mask: Tensor, memory: Tensor, *, cache: List[Tensor] | None = None) → Tuple[Tensor, List[Tensor]]
Forward one step.
- Parameters:
- tgt – input token ids, int64 (batch, maxlen_out)
- tgt_mask – input token mask, (batch, maxlen_out); dtype=torch.uint8 before PyTorch 1.2, dtype=torch.bool in PyTorch 1.2 and later
- memory – encoded memory, float32 (batch, maxlen_in, feat)
- cache – cached output list of (batch, max_time_out-1, size)
- Returns: NN output value y and the cache per self.decoders; y.shape is (batch, maxlen_out, token)
- Return type: (y, cache)
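In autoregressive decoders, the target mask is conventionally causal: position i may attend only to positions 0..i. The sketch below builds such a lower-triangular boolean mask with plain lists. It assumes the ESPnet convention that True marks an allowed position (torch.nn.MultiheadAttention uses the opposite meaning for boolean masks). Whether this class consumes tgt_mask in exactly this form is not stated above, so treat the helpers as hypothetical.

```python
from typing import List


def subsequent_mask(size: int) -> List[List[bool]]:
    """Lower-triangular causal mask: row i may attend to columns 0..i."""
    return [[j <= i for j in range(size)] for i in range(size)]


def batch_causal_mask(batch: int, size: int) -> List[List[List[bool]]]:
    """Repeat the causal mask for each batch item: shape (batch, size, size)."""
    return [subsequent_mask(size) for _ in range(batch)]


mask = subsequent_mask(3)
# Row 0 sees only token 0; row 2 sees tokens 0, 1 and 2.
```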
NOTE (Shih-Lun): the cache implementation is ignored for now, for simplicity and correctness.
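Ignoring the cache means each decoding step re-processes the entire prefix rather than only the newest token. The sketch below shows a greedy loop in that style. `step` is a hypothetical function that maps a full prefix to next-token scores; it stands in for a call to forward_one_step and is not its real signature. Over n steps the loop feeds 1 + 2 + ... + n token positions to `step`, so the work grows quadratically in output length. A cache would avoid that by processing only one new position per step.

```python
from typing import Callable, List

# Hypothetical step function: full prefix -> scores for the next token.
StepFn = Callable[[List[int]], List[float]]


def greedy_decode_no_cache(step: StepFn, sos: int, eos: int, max_len: int) -> List[int]:
    """Greedy decoding that re-runs `step` on the whole prefix every time."""
    ys = [sos]
    for _ in range(max_len):
        scores = step(ys)  # the full prefix is recomputed; nothing is cached
        next_tok = max(range(len(scores)), key=scores.__getitem__)
        ys.append(next_tok)
        if next_tok == eos:
            break
    return ys


# Toy step that counts upward until token 4 (used as eos here).
prefix_lengths: List[int] = []


def toy_step(prefix: List[int]) -> List[float]:
    prefix_lengths.append(len(prefix))  # record how many positions were fed
    nxt = min(prefix[-1] + 1, 4)
    return [1.0 if v == nxt else 0.0 for v in range(5)]


out = greedy_decode_no_cache(toy_step, sos=0, eos=4, max_len=10)
```

Here `out` is `[0, 1, 2, 3, 4]`, and `prefix_lengths` records the growing prefixes `[1, 2, 3, 4]` that were re-fed at each step.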
score(ys, state, x)
Score.
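score(ys, state, x) presumably scores a single hypothesis, following the usual ESPnet scorer convention: one prefix, one state, one encoder feature in, and one score vector plus one new state out. batch_score does the same for n_batch hypotheses at once. The sketch below shows how the two relate by lifting a per-hypothesis score function to a batched one with a loop. A dedicated batch_score avoids that loop by running all prefixes through the decoder in one pass. All names here are hypothetical.

```python
from typing import Any, Callable, List, Sequence, Tuple

# Hypothetical per-hypothesis scorer: (prefix, state, x) -> (scores, new_state).
ScoreFn = Callable[[List[int], Any, Any], Tuple[List[float], Any]]


def batch_score_via_score(
    score: ScoreFn,
    ys: Sequence[List[int]],
    states: Sequence[Any],
    xs: Sequence[Any],
) -> Tuple[List[List[float]], List[Any]]:
    """Emulate batch scoring by calling score() once per hypothesis."""
    rows: List[List[float]] = []
    new_states: List[Any] = []
    for y, s, x in zip(ys, states, xs):
        logp, new_s = score(y, s, x)
        rows.append(logp)
        new_states.append(new_s)
    return rows, new_states


def toy_score(y: List[int], state: int, x: float) -> Tuple[List[float], int]:
    """Invented scorer: scores depend on prefix length and x; state counts calls."""
    return [float(len(y)), float(x)], state + 1


rows, states = batch_score_via_score(toy_score, [[1], [1, 2]], [0, 5], [10, 20])
```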