Source code for espnet2.asr_transducer.encoder.encoder

"""Encoder for Transducer model."""

from typing import Any, Dict, List, Tuple

import torch
from typeguard import typechecked

from espnet2.asr_transducer.encoder.building import (
    build_body_blocks,
    build_input_block,
    build_main_parameters,
    build_positional_encoding,
)
from espnet2.asr_transducer.encoder.validation import validate_architecture
from espnet2.asr_transducer.utils import (
    TooShortUttError,
    check_short_utt,
    make_chunk_mask,
    make_source_mask,
)


class Encoder(torch.nn.Module):
    """Encoder module definition.

    Args:
        input_size: Input size.
        body_conf: Encoder body configuration.
        input_conf: Encoder input configuration.
        main_conf: Encoder main configuration.

    """

    @typechecked
    def __init__(
        self,
        input_size: int,
        body_conf: List[Dict[str, Any]],
        input_conf: Dict[str, Any] = {},
        main_conf: Dict[str, Any] = {},
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()

        embed_size, output_size = validate_architecture(
            input_conf, body_conf, input_size
        )
        main_params = build_main_parameters(**main_conf)

        self.embed = build_input_block(input_size, input_conf)
        self.pos_enc = build_positional_encoding(embed_size, main_params)
        self.encoders = build_body_blocks(body_conf, main_params, output_size)

        self.output_size = output_size

        self.dynamic_chunk_training = main_params["dynamic_chunk_training"]
        self.short_chunk_threshold = main_params["short_chunk_threshold"]
        self.short_chunk_size = main_params["short_chunk_size"]
        self.num_left_chunks = main_params["num_left_chunks"]
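
    # Construction sketch (added comment, not part of the original source).
    # The configuration keys below are hypothetical placeholders; the actual
    # supported options are defined in espnet2.asr_transducer.encoder.building
    # (see build_input_block / build_body_blocks / build_main_parameters):
    #
    #     encoder = Encoder(
    #         input_size=80,  # e.g. 80-dim log-mel features
    #         body_conf=[{"block_type": "conformer"}],  # hypothetical minimal block config
    #     )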

    def reset_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder cache for streaming.

        Args:
            left_context: Number of previous frames (AFTER subsampling) the
                          attention module can see in current chunk.
            device: Device ID.

        """
        return self.encoders.reset_streaming_cache(left_context, device)
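
    # Streaming note (added comment, not part of the original source): the
    # cache reset is typically issued once per utterance, before the first
    # call to chunk_forward(), e.g.:
    #
    #     encoder.reset_cache(
    #         left_context=32, device=next(encoder.parameters()).device
    #     )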

    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
            x_len: Encoder outputs lengths. (B,)

        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)

        x, mask = self.embed(x, mask)
        pos_enc = self.pos_enc(x)

        if self.dynamic_chunk_training:
            max_len = x.size(1)
            chunk_size = torch.randint(1, max_len, (1,)).item()

            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1

            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                num_left_chunks=self.num_left_chunks,
                device=x.device,
            )
        else:
            chunk_mask = None

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        return x, mask.eq(0).sum(1)
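
    # Worked example of the dynamic chunk sampling above (added comment,
    # illustrative numbers): with max_len=120, short_chunk_threshold=0.75 and
    # short_chunk_size=25, a draw of chunk_size=100 exceeds 120 * 0.75 = 90,
    # so the full context is used; a draw of chunk_size=60 is folded down to
    # (60 % 25) + 1 = 11 frames per chunk.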

    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.Tensor,
        left_context: int = 32,
    ) -> torch.Tensor:
        """Encode input sequences as chunks.

        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            left_context: Number of previous frames (AFTER subsampling) the
                          attention module can see in current chunk.

        Returns:
            x: Encoder outputs. (1, T_out, D_enc)

        """
        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask)

        x = x[:, 1:-1, :]
        mask = mask[:, 1:-1]

        pos_enc = self.pos_enc(x, left_context=left_context)

        processed_mask = (
            torch.arange(left_context, device=x.device).view(1, left_context).flip(1)
        )
        processed_mask = processed_mask >= processed_frames

        mask = torch.cat([processed_mask, mask], dim=1)

        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            left_context=left_context,
        )

        return x
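

if __name__ == "__main__":
    # Standalone sketch (added for illustration; not part of the original
    # module). It replays the dynamic chunk-size sampling from forward() and
    # builds the corresponding attention mask with make_chunk_mask, whose call
    # pattern matches the one used above. The parameter values below are
    # illustrative assumptions, not the training defaults.
    max_len = 120
    short_chunk_threshold = 0.75
    short_chunk_size = 25
    num_left_chunks = 4

    # Sample a candidate chunk size uniformly from [1, max_len).
    chunk_size = torch.randint(1, max_len, (1,)).item()

    if chunk_size > (max_len * short_chunk_threshold):
        # Large draws fall back to full-context training.
        chunk_size = max_len
    else:
        # Small draws are folded into [1, short_chunk_size].
        chunk_size = (chunk_size % short_chunk_size) + 1

    chunk_mask = make_chunk_mask(
        max_len,
        chunk_size,
        num_left_chunks=num_left_chunks,
        device=torch.device("cpu"),
    )

    print(f"chunk_size={chunk_size}, chunk_mask shape={tuple(chunk_mask.shape)}")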