Source code for espnet2.bin.tokenize_text

#!/usr/bin/env python3
import argparse
import logging
import sys
from collections import Counter
from pathlib import Path
from typing import List, Optional

from typeguard import typechecked

from espnet2.text.build_tokenizer import build_tokenizer
from espnet2.text.cleaner import TextCleaner
from espnet2.text.phoneme_tokenizer import g2p_choices
from espnet2.utils.types import str2bool, str_or_none
from espnet.utils.cli_utils import get_commandline_args


def field2slice(field: Optional[str]) -> slice:
    """Convert a 1-based, ``cut``-style field string into a 0-based slice.

    The field string follows the convention of the ``cut -f`` command: the
    numbers are 1-based and a range includes both of its end points.

    Args:
        field: Field specification such as ``"2"``, ``"2-"``, ``"2-5"`` or
            ``"-5"``. ``None`` selects every column.

    Returns:
        The slice that selects the requested columns from a list of columns.

    Raises:
        RuntimeError: If the string is not a valid field specification,
            e.g. it is not numeric or it uses ``0`` as the first column.

    Examples:
        >>> field2slice("1-")
        slice(0, None, None)
        >>> field2slice("1-3")
        slice(0, 3, None)
        >>> field2slice("-3")
        slice(None, 3, None)
        >>> field2slice("2")
        slice(1, 2, None)
    """
    if field is None:
        # No restriction requested: keep all the columns.
        return slice(None)

    field = field.strip()
    try:
        if "-" in field:
            # e.g. "2-" or "2-5" or "-7"
            first, last = field.split("-", maxsplit=1)
            start = int(first) if first.strip() != "" else None
            if start == 0:
                raise ValueError("1-based string")
            stop = int(last) if last.strip() != "" else None
        else:
            # e.g. "2": exactly one column, i.e. the same as "2-2"
            start = int(field)
            if start == 0:
                raise ValueError("must be 1 or more value")
            stop = start
    except ValueError as err:
        raise RuntimeError(
            f"Format error: e.g. '2-', '2-5', or '-5': {field}"
        ) from err

    if start is None:
        return slice(None, stop)
    # -1 because of 1-based integer following "cut" command
    # e.g "1-3" -> slice(0, 3)
    return slice(start - 1, stop)
@typechecked
def tokenize(
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
    add_nonsplit_symbol: List[str],
):
    """Tokenize a text file line by line, or build a vocabulary from it.

    In the default mode every input line is written out as its tokens joined
    by a single space. In ``write_vocabulary`` mode nothing is written per
    line; the tokens are counted instead and the resulting token list is
    written, one token per line.

    Args:
        input: Input text path. ``"-"`` reads from ``sys.stdin``.
        output: Output path. ``"-"`` writes to ``sys.stdout``. Missing parent
            directories are created.
        field: 1-based column selection (see ``field2slice``), e.g. ``"2-"``.
            ``None`` uses the whole line.
        delimiter: Column delimiter used to split and re-join the selected
            columns; also passed to the tokenizer. ``None`` splits on
            whitespace and re-joins with a single space.
        token_type: Passed to ``build_tokenizer``.
        space_symbol: Passed to ``build_tokenizer``.
        non_linguistic_symbols: Passed to ``build_tokenizer``.
        bpemodel: Passed to ``build_tokenizer``.
        log_level: Level given to ``logging.basicConfig``.
        write_vocabulary: Write the token list instead of tokenized text.
        vocabulary_size: If positive, keep at most
            ``vocabulary_size - len(add_symbol)`` counted tokens.
        remove_non_linguistic_symbols: Passed to ``build_tokenizer``.
        cutoff: Only tokens occurring more than ``cutoff`` times are kept.
        add_symbol: ``"<symbol>:<index>"`` entries inserted into the
            vocabulary at the given index (negative counts from the end).
        cleaner: Passed to ``TextCleaner``.
        g2p: Passed to ``build_tokenizer`` as ``g2p_type``.
        add_nonsplit_symbol: Same format as ``add_symbol``; also passed to
            ``build_tokenizer`` as ``nonsplit_symbol``.

    Raises:
        RuntimeError: If ``vocabulary_size`` is smaller than the number of
            ``add_symbol`` entries, or a symbol entry is not of the form
            ``"<symbol>:<index>"``.
    """
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    fin = sys.stdin if input == "-" else Path(input).open("r", encoding="utf-8")
    fout = None
    try:
        if output == "-":
            fout = sys.stdout
        else:
            p = Path(output)
            p.parent.mkdir(parents=True, exist_ok=True)
            fout = p.open("w", encoding="utf-8")

        text_cleaner = TextCleaner(cleaner)
        tokenizer = build_tokenizer(
            token_type=token_type,
            bpemodel=bpemodel,
            delimiter=delimiter,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            remove_non_linguistic_symbols=remove_non_linguistic_symbols,
            g2p_type=g2p,
            nonsplit_symbol=add_nonsplit_symbol,
        )

        counter = Counter()
        field_slice = field2slice(field) if field is not None else None

        for line in fin:
            line = line.rstrip()
            if field_slice is not None:
                # e.g. field="2-"
                # uttidA hello world!! -> hello world!!
                columns = line.split(delimiter)[field_slice]
                joiner = " " if delimiter is None else delimiter
                line = joiner.join(columns)

            line = text_cleaner(line)
            tokens = tokenizer.text2tokens(line)
            if not write_vocabulary:
                fout.write(" ".join(tokens) + "\n")
            else:
                for t in tokens:
                    counter[t] += 1

        if not write_vocabulary:
            return

        # ======= write_vocabulary mode from here =======
        # Sort by the number of occurrences in descending order
        # and filter lower frequency words than cutoff value
        words_and_counts = [
            item
            for item in sorted(counter.items(), key=lambda x: -x[1])
            if item[1] > cutoff
        ]

        # Restrict the vocabulary size
        if vocabulary_size > 0:
            if vocabulary_size < len(add_symbol):
                raise RuntimeError(
                    f"vocabulary_size is too small: {vocabulary_size}"
                )
            words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]

        # Parse the values of --add_symbol and --add_nonsplit_symbol
        for symbol_and_id in add_symbol + add_nonsplit_symbol:
            # e.g symbol="<blank>:0"
            try:
                symbol, idx = symbol_and_id.split(":")
                idx = int(idx)
            except ValueError as err:
                raise RuntimeError(
                    f"Format error: e.g. '<blank>:0': {symbol_and_id}"
                ) from err
            symbol = symbol.strip()

            # e.g. idx=0  -> append as the first symbol
            # e.g. idx=-1 -> append as the last symbol
            if idx < 0:
                idx = len(words_and_counts) + 1 + idx
            # Added symbols carry no count so they are excluded from the
            # in-vocabulary total below.
            words_and_counts.insert(idx, (symbol, None))

        # Write words
        for w, _ in words_and_counts:
            fout.write(w + "\n")

        # Logging
        total_count = sum(counter.values())
        invocab_count = sum(c for _, c in words_and_counts if c is not None)
        if total_count > 0:
            logging.info(
                f"OOV rate = {(total_count - invocab_count) / total_count * 100} %"
            )
        else:
            # Empty input: the rate is undefined, avoid dividing by zero.
            logging.info("OOV rate is undefined: no tokens were counted")
    finally:
        # Close only the handles opened here; never close stdin/stdout.
        if fout is not None and fout is not sys.stdout:
            fout.close()
        if fin is not sys.stdin:
            fin.close()
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser of the tokenizer script.

    Returns:
        The parser. The destination names of its options match the keyword
        arguments of ``tokenize``.
    """
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        # Upper-case first so that e.g. "info" is accepted by ``choices``.
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument(
        "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
    )
    parser.add_argument(
        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
    )
    parser.add_argument(
        "--field",
        "-f",
        help="The target columns of the input text as 1-based integer. e.g 2-",
    )
    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
    parser.add_argument(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="non_linguistic_symbols file path",
    )
    parser.add_argument(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-language-symbols from tokens",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[
            None,
            "tacotron",
            "jaconv",
            "vietnamese",
            "korean_cleaner",
            "whisper_en",
            "whisper_basic",
        ],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_choices,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    group = parser.add_argument_group("write_vocabulary mode related")
    group.add_argument(
        "--write_vocabulary",
        type=str2bool,
        default=False,
        help="Write tokens list instead of tokenized text per line",
    )
    group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
    group.add_argument(
        "--cutoff",
        default=0,
        type=int,
        help="cut-off frequency used for write-vocabulary mode",
    )
    group.add_argument(
        "--add_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
    )
    group.add_argument(
        "--add_nonsplit_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol that is nonsplit e.g. --add_nonsplit_symbol '<sc>:2'",
    )

    return parser
def main(cmd=None):
    """Run the tokenizer as a command-line program.

    Args:
        cmd: Argument list handed to ``parse_args``. With ``None`` argparse
            reads the arguments from ``sys.argv``.
    """
    # Echo the invoking command line to stderr before doing any work.
    print(get_commandline_args(), file=sys.stderr)
    namespace = get_parser().parse_args(cmd)
    # Every parsed option is forwarded to ``tokenize`` as a keyword argument.
    tokenize(**vars(namespace))
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()