Source code for espnet.bin.asr_enhance

#!/usr/bin/env python3
import logging
import os
import random
import sys
from distutils.util import strtobool

import configargparse
import numpy as np

from espnet.asr.pytorch_backend.asr import enhance


# NOTE: you need this func to generate our sphinx doc
def get_parser():
    """Build the argument parser for the speech-enhancement script.

    Options can be given on the command line or through up to three YAML
    config files (``--config``, ``--config2``, ``--config3``).

    Returns:
        configargparse.ArgumentParser: Parser with all enhancement options
            registered.

    """
    parser = configargparse.ArgumentParser(
        description="Enhance noisy speech for speech recognition",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add("--config", is_config_file=True, help="config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="second config file path that overwrites the settings in `--config`.",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="third config file path that overwrites the settings "
        "in `--config` and `--config2`.",
    )

    parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
    # NOTE: main() in this file raises ValueError for any backend other than
    # "pytorch", so the "chainer" default has to be overridden by the caller.
    parser.add_argument(
        "--backend",
        default="chainer",
        type=str,
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option")
    parser.add_argument(
        "--batchsize",
        default=1,
        type=int,
        help="Batch size for beam search (0: means no batch processing)",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    # task related
    parser.add_argument(
        "--recog-json", type=str, help="Filename of recognition data (json)"
    )
    # model (parameter) related
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file"
    )

    # Outputs configuration
    parser.add_argument(
        "--enh-wspecifier",
        type=str,
        default=None,
        # A separating space was missing between the two literals.
        help="Specify the output way for enhanced speech. "
        "e.g. ark,scp:outdir,wav.scp",
    )
    parser.add_argument(
        "--enh-filetype",
        type=str,
        default="sound",
        choices=["mat", "hdf5", "sound.hdf5", "sound"],
        help="Specify the file format for enhanced speech. "
        '"mat" is the matrix format in kaldi',
    )
    parser.add_argument("--fs", type=int, default=16000, help="The sample frequency")
    parser.add_argument(
        "--keep-length",
        type=strtobool,
        default=True,
        help="Adjust the output length to match "
        "with the input for enhanced speech",
    )
    parser.add_argument(
        "--image-dir", type=str, default=None, help="The directory saving the images."
    )
    parser.add_argument(
        "--num-images",
        type=int,
        default=20,
        help="The number of images files to be saved. "
        "If negative, all samples are to be saved.",
    )

    # IStft
    parser.add_argument(
        "--apply-istft",
        type=strtobool,
        default=True,
        help="Apply istft to the output from the network",
    )
    parser.add_argument(
        "--istft-win-length",
        type=int,
        default=512,
        help="The window length for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    parser.add_argument(
        "--istft-n-shift",
        # Was type=str with an int default, so a value given on the command
        # line arrived as a string while the default stayed an integer.
        type=int,
        default=256,
        help="The shift length for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    parser.add_argument(
        "--istft-window",
        type=str,
        default="hann",
        help="The window type for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    return parser
def main(args):
    """Parse the arguments, set up the runtime, and run enhancement.

    Args:
        args (list): Command-line argument strings, handed to the parser
            returned by ``get_parser()``.

    Raises:
        ValueError: If the selected backend is not ``"pytorch"``.
        SystemExit: Via ``sys.exit(1)`` when ``--ngpu`` disagrees with
            ``CUDA_VISIBLE_DEVICES`` or when more than one GPU is requested.

    """
    args = get_parser().parse_args(args)

    # logging info: verbose 1 -> INFO, 2 -> DEBUG, anything else -> WARN
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    verbose_levels = {1: logging.INFO, 2: logging.DEBUG}
    quiet = args.verbose not in verbose_levels
    logging.basicConfig(
        level=logging.WARN if quiet else verbose_levels[args.verbose],
        format=log_format,
    )
    if quiet:
        logging.warning("Skip DEBUG/INFO messages")

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible_devices is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif len(visible_devices.split(",")) != args.ngpu:
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

        # TODO(kamo): support of multiple GPUs
        if args.ngpu > 1:
            logging.error("The program only supports ngpu=1.")
            sys.exit(1)

    # display PYTHONPATH
    python_path = os.environ.get("PYTHONPATH", "(None)")
    logging.info("python path = " + python_path)

    # seed setting (Python's and NumPy's global generators)
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    logging.info("set random seed = %d" % seed)

    # recog
    logging.info("backend = " + args.backend)
    if args.backend != "pytorch":
        raise ValueError("Only pytorch is supported.")
    enhance(args)
# Script entry point: pass the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])