Source code for espnet2.asr_transducer.encoder.validation

"""Set of methods to validate encoder architecture."""

from typing import Any, Dict, List, Tuple

from espnet2.asr_transducer.utils import get_convinput_module_parameters


def validate_block_arguments(
    configuration: Dict[str, Any],
    block_id: int,
    previous_block_output: int,
) -> Tuple[int, int]:
    """Validate block arguments.

    Args:
        configuration: Architecture configuration. For 'conv1d' blocks, the key
            'input_size' is written in place with `previous_block_output`.
        block_id: Block ID (only used in error messages).
        previous_block_output: Previous block output size.

    Returns:
        input_size: Block input size.
        output_size: Block output size.

    Raises:
        ValueError: If the block type is missing or unsupported, or if an
            argument required by the block type is missing.

    """
    block_type = configuration.get("block_type")

    if block_type is None:
        raise ValueError(
            "Block %d in encoder doesn't have a type assigned. " % block_id
        )

    if block_type in ["branchformer", "conformer", "ebranchformer"]:
        # 'hidden_size' is used below as both input and output size, so it must
        # be validated too; otherwise None sizes leak into later checks.
        # Order is kept so the pre-existing checks still fire first.
        for key in ("linear_size", "conv_mod_kernel_size", "hidden_size"):
            if configuration.get(key) is None:
                raise ValueError(
                    "Missing '%s' argument for X-former block (ID: %d)"
                    % (key, block_id)
                )

        input_size = output_size = configuration["hidden_size"]
    elif block_type == "conv1d":
        output_size = configuration.get("output_size")

        if output_size is None:
            raise ValueError(
                "Missing 'output_size' argument for Conv1d block (ID: %d)" % block_id
            )

        if configuration.get("kernel_size") is None:
            raise ValueError(
                "Missing 'kernel_size' argument for Conv1d block (ID: %d)" % block_id
            )

        # Conv1d input size is taken from the previous block and recorded in
        # the configuration.
        input_size = configuration["input_size"] = previous_block_output
    else:
        raise ValueError("Block type: %s is not supported." % block_type)

    return input_size, output_size
def validate_input_block(
    configuration: Dict[str, Any], body_first_conf: Dict[str, Any], input_size: int
) -> int:
    """Validate input block.

    The input block configuration is completed in place: 'subsampling_factor'
    (defaulting to 4), 'conv_size', 'output_size' and 'vgg_like' are written back.

    Args:
        configuration: Encoder input block configuration.
        body_first_conf: Encoder first body block configuration.
        input_size: Encoder input block input size.

    Return:
        output_size: Encoder input block output size, or -1 if the first body
            block has no type or an unsupported one.

    Raises:
        ValueError: If the subsampling factor is not supported by the selected
            input module (VGG2L: 4, 6; Conv2D: 2, 4, 6).

    """
    vgg_like = configuration.get("vgg_like", False)
    next_block_type = body_first_conf.get("block_type")
    allowed_next_block_type = ["branchformer", "conformer", "conv1d", "ebranchformer"]

    # Missing/unsupported first body block type is signalled with -1, not raised.
    if next_block_type is None or (next_block_type not in allowed_next_block_type):
        return -1

    if configuration.get("subsampling_factor") is None:
        configuration["subsampling_factor"] = 4
    sub_factor = configuration["subsampling_factor"]

    if vgg_like:
        conv_size = configuration.get("conv_size", (64, 128))
        # A single int is expanded to the same size for both conv stages.
        if isinstance(conv_size, int):
            conv_size = (conv_size, conv_size)

        if sub_factor not in [4, 6]:
            raise ValueError(
                "VGG2L input module only support subsampling factor of 4 and 6."
            )
    else:
        conv_size = configuration.get("conv_size", None)
        # Accept lists as well as tuples (e.g. a sequence read from a config
        # file); only the first entry is used for the Conv2D module.
        if isinstance(conv_size, (list, tuple)):
            conv_size = conv_size[0]

        if sub_factor not in [2, 4, 6]:
            raise ValueError(
                "Conv2D input module only support subsampling factor of 2, 4 and 6."
            )

    if next_block_type == "conv1d":
        if vgg_like:
            _, output_size = get_convinput_module_parameters(
                input_size, conv_size[1], sub_factor, is_vgg=True
            )
        else:
            if conv_size is None:
                conv_size = body_first_conf.get("output_size", 64)

            _, output_size = get_convinput_module_parameters(
                input_size, conv_size, sub_factor, is_vgg=False
            )

        configuration["output_size"] = None
    else:
        # X-former body: the input block projects to the body's hidden size.
        output_size = body_first_conf.get("hidden_size")

        if conv_size is None:
            conv_size = output_size

        configuration["output_size"] = output_size

    configuration["conv_size"] = conv_size
    configuration["vgg_like"] = vgg_like

    return output_size
def validate_architecture(
    input_conf: Dict[str, Any], body_conf: List[Dict[str, Any]], input_size: int
) -> Tuple[int, int]:
    """Validate specified architecture is valid.

    Args:
        input_conf: Encoder input block configuration.
        body_conf: Encoder body blocks configuration.
        input_size: Encoder input size.

    Returns:
        input_block_osize: Encoder input block output size.
        body_block_osize: Encoder body output size (output size of last block).

    Raises:
        ValueError: If the body configuration is empty, a block configuration is
            invalid, or consecutive body blocks have mismatching output/input sizes.

    """
    if not body_conf:
        raise ValueError("The encoder body configuration is empty.")

    input_block_osize = validate_input_block(input_conf, body_conf[0], input_size)

    # First pass: validate every block's own arguments (block IDs are 1-based).
    cmp_io = []
    previous_output = input_block_osize
    for i, b in enumerate(body_conf):
        _io = validate_block_arguments(b, (i + 1), previous_output)
        cmp_io.append(_io)
        previous_output = _io[1]

    # Second pass: check that consecutive blocks are size-compatible. IDs in the
    # message are 1-based to match the IDs used in per-block error messages.
    for block_id, ((_, prev_out), (cur_in, _)) in enumerate(
        zip(cmp_io, cmp_io[1:]), start=2
    ):
        if prev_out != cur_in:
            raise ValueError(
                "Output/Input mismatch between blocks %d and %d"
                " in the encoder body." % (block_id - 1, block_id)
            )

    return input_block_osize, cmp_io[-1][1]