"""Custom decoder definition for Transducer model."""
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.utils import (
check_batch_states,
check_state,
pad_sequence,
)
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.transducer_decoder_interface import (
ExtendedHypothesis,
Hypothesis,
TransducerDecoderInterface,
)
[docs]class CustomDecoder(TransducerDecoderInterface, torch.nn.Module):
"""Custom decoder module for Transducer model.
Args:
odim: Output dimension.
dec_arch: Decoder block architecture (type and parameters).
input_layer: Input layer type.
repeat_block: Number of times dec_arch is repeated.
joint_activation_type: Type of activation for joint network.
positional_encoding_type: Positional encoding type.
positionwise_layer_type: Positionwise layer type.
positionwise_activation_type: Positionwise activation type.
input_layer_dropout_rate: Dropout rate for input layer.
blank_id: Blank symbol ID.
"""
def __init__(
self,
odim: int,
dec_arch: List,
input_layer: str = "embed",
repeat_block: int = 0,
joint_activation_type: str = "tanh",
positional_encoding_type: str = "abs_pos",
positionwise_layer_type: str = "linear",
positionwise_activation_type: str = "relu",
input_layer_dropout_rate: float = 0.0,
blank_id: int = 0,
):
"""Construct a CustomDecoder object."""
torch.nn.Module.__init__(self)
self.embed, self.decoders, ddim, _ = build_blocks(
"decoder",
odim,
input_layer,
dec_arch,
repeat_block=repeat_block,
positional_encoding_type=positional_encoding_type,
positionwise_layer_type=positionwise_layer_type,
positionwise_activation_type=positionwise_activation_type,
input_layer_dropout_rate=input_layer_dropout_rate,
padding_idx=blank_id,
)
self.after_norm = LayerNorm(ddim)
self.dlayers = len(self.decoders)
self.dunits = ddim
self.odim = odim
self.blank_id = blank_id
[docs] def set_device(self, device: torch.device):
"""Set GPU device to use.
Args:
device: Device ID.
"""
self.device = device
[docs] def init_state(
self,
batch_size: Optional[int] = None,
) -> List[Optional[torch.Tensor]]:
"""Initialize decoder states.
Args:
batch_size: Batch size.
Returns:
state: Initial decoder hidden states. [N x None]
"""
state = [None] * self.dlayers
return state
[docs] def forward(
self, dec_input: torch.Tensor, dec_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode label ID sequences.
Args:
dec_input: Label ID sequences. (B, U)
dec_mask: Label mask sequences. (B, U)
Return:
dec_output: Decoder output sequences. (B, U, D_dec)
dec_output_mask: Mask of decoder output sequences. (B, U)
"""
dec_input = self.embed(dec_input)
dec_output, dec_mask = self.decoders(dec_input, dec_mask)
dec_output = self.after_norm(dec_output)
return dec_output, dec_mask
[docs] def score(
self, hyp: Hypothesis, cache: Dict[str, Any]
) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
"""One-step forward hypothesis.
Args:
hyp: Hypothesis.
cache: Pairs of (dec_out, dec_state) for each label sequence. (key)
Returns:
dec_out: Decoder output sequence. (1, D_dec)
dec_state: Decoder hidden states. [N x (1, U, D_dec)]
lm_label: Label ID for LM. (1,)
"""
labels = torch.tensor([hyp.yseq], device=self.device)
lm_label = labels[:, -1]
str_labels = "_".join(list(map(str, hyp.yseq)))
if str_labels in cache:
dec_out, dec_state = cache[str_labels]
else:
dec_out_mask = subsequent_mask(len(hyp.yseq)).unsqueeze_(0)
new_state = check_state(hyp.dec_state, (labels.size(1) - 1), self.blank_id)
dec_out = self.embed(labels)
dec_state = []
for s, decoder in zip(new_state, self.decoders):
dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
dec_state.append(dec_out)
dec_out = self.after_norm(dec_out[:, -1])
cache[str_labels] = (dec_out, dec_state)
return dec_out[0], dec_state, lm_label
[docs] def batch_score(
self,
hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
dec_states: List[Optional[torch.Tensor]],
cache: Dict[str, Any],
use_lm: bool,
) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
"""One-step forward hypotheses.
Args:
hyps: Hypotheses.
dec_states: Decoder hidden states. [N x (B, U, D_dec)]
cache: Pairs of (h_dec, dec_states) for each label sequences. (keys)
use_lm: Whether to compute label ID sequences for LM.
Returns:
dec_out: Decoder output sequences. (B, D_dec)
dec_states: Decoder hidden states. [N x (B, U, D_dec)]
lm_labels: Label ID sequences for LM. (B,)
"""
final_batch = len(hyps)
process = []
done = [None] * final_batch
for i, hyp in enumerate(hyps):
str_labels = "_".join(list(map(str, hyp.yseq)))
if str_labels in cache:
done[i] = cache[str_labels]
else:
process.append((str_labels, hyp.yseq, hyp.dec_state))
if process:
labels = pad_sequence([p[1] for p in process], self.blank_id)
labels = torch.LongTensor(labels, device=self.device)
p_dec_states = self.create_batch_states(
self.init_state(),
[p[2] for p in process],
labels,
)
dec_out = self.embed(labels)
dec_out_mask = (
subsequent_mask(labels.size(-1))
.unsqueeze_(0)
.expand(len(process), -1, -1)
)
new_states = []
for s, decoder in zip(p_dec_states, self.decoders):
dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
new_states.append(dec_out)
dec_out = self.after_norm(dec_out[:, -1])
j = 0
for i in range(final_batch):
if done[i] is None:
state = self.select_state(new_states, j)
done[i] = (dec_out[j], state)
cache[process[j][0]] = (dec_out[j], state)
j += 1
dec_out = torch.stack([d[0] for d in done])
dec_states = self.create_batch_states(
dec_states, [d[1] for d in done], [[0] + h.yseq for h in hyps]
)
if use_lm:
lm_labels = torch.LongTensor(
[hyp.yseq[-1] for hyp in hyps], device=self.device
)
return dec_out, dec_states, lm_labels
return dec_out, dec_states, None
[docs] def select_state(
self, states: List[Optional[torch.Tensor]], idx: int
) -> List[Optional[torch.Tensor]]:
"""Get specified ID state from decoder hidden states.
Args:
states: Decoder hidden states. [N x (B, U, D_dec)]
idx: State ID to extract.
Returns:
state_idx: Decoder hidden state for given ID. [N x (1, U, D_dec)]
"""
if states[0] is None:
return states
state_idx = [states[layer][idx] for layer in range(self.dlayers)]
return state_idx
[docs] def create_batch_states(
self,
states: List[Optional[torch.Tensor]],
new_states: List[Optional[torch.Tensor]],
check_list: List[List[int]],
) -> List[Optional[torch.Tensor]]:
"""Create decoder hidden states sequences.
Args:
states: Decoder hidden states. [N x (B, U, D_dec)]
new_states: Decoder hidden states. [B x [N x (1, U, D_dec)]]
check_list: Label ID sequences.
Returns:
states: New decoder hidden states. [N x (B, U, D_dec)]
"""
if new_states[0][0] is None:
return states
max_len = max(len(elem) for elem in check_list) - 1
for layer in range(self.dlayers):
states[layer] = check_batch_states(
[s[layer] for s in new_states], max_len, self.blank_id
)
return states