# Source code for espnet2.text.whisper_tokenizer

import copy
import os
from typing import Iterable, List, Optional

from typeguard import typechecked

from espnet2.text.abs_tokenizer import AbsTokenizer

    "noinfo": "english",  # default, English
    "ca": "catalan",
    "cs": "czech",
    "cy": "welsh",
    "de": "german",
    "en": "english",
    "eu": "basque",
    "es": "spanish",
    "fa": "persian",
    "fr": "french",
    "it": "italian",
    "ja": "japanese",
    "jpn": "japanese",
    "ko": "korean",
    "kr": "korean",
    "nl": "dutch",
    "pl": "polish",
    "pt": "portuguese",
    "ru": "russian",
    "tt": "tatar",
    "zh": "chinese",
    "zh-TW": "chinese",
    "zh-CN": "chinese",
    "zh-HK": "chinese",
dirname = os.path.dirname(__file__)

class OpenAIWhisperTokenizer(AbsTokenizer):
    """Text <-> token-string converter backed by the OpenAI Whisper tokenizer.

    The Whisper tokenizer object is obtained from
    ``whisper.tokenizer.get_tokenizer`` and deep-copied; timestamp tokens and,
    optionally, a speaker-change symbol are then registered on the copy as
    additional special tokens.

    Args:
        model_type: ``"whisper_en"`` or ``"whisper_multilingual"``.
        language: Language code; must be a key of ``LANGUAGES_CODE_MAPPING``.
        task: ``"transcribe"`` or ``"translate"``.
        sot: If True, ``speaker_change_symbol`` is appended to the additional
            special tokens.
        speaker_change_symbol: Token string registered when ``sot`` is True.
        added_tokens_txt: Optional path to a text file with one extra token
            per line. Only read when ``model_type`` is
            ``"whisper_multilingual"``.

    Raises:
        ValueError: If ``language``, ``task`` or ``model_type`` is unsupported.
        Exception: Whatever importing ``whisper.tokenizer`` raised, re-raised
            after printing an installation hint.
    """

    @typechecked
    def __init__(
        self,
        model_type: str,
        language: str = "en",
        task: str = "transcribe",
        sot: bool = False,
        speaker_change_symbol: str = "<sc>",
        added_tokens_txt: Optional[str] = None,
    ):
        try:
            import whisper.tokenizer
        except Exception:
            print("Error: whisper is not properly installed.")
            # NOTE(review): the installer path below looks truncated (no script
            # name after "./installers/") -- confirm against the repository.
            print(
                "Please install whisper with: cd ${MAIN_ROOT}/tools && "
                "./installers/"
            )
            raise

        self.model = model_type

        self.language = LANGUAGES_CODE_MAPPING.get(language)
        if self.language is None:
            # Report the code the caller passed; self.language is None here.
            raise ValueError(f"language: {language} unsupported for Whisper model")

        self.task = task
        if self.task not in ["transcribe", "translate"]:
            raise ValueError(f"task: {self.task} unsupported for Whisper model")

        if model_type == "whisper_en":
            self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual=False)
        elif model_type == "whisper_multilingual":
            self.tokenizer = whisper.tokenizer.get_tokenizer(
                multilingual=True, language=self.language, task=self.task
            )
            if added_tokens_txt is not None:
                # One token per line; trailing whitespace/newline is stripped.
                with open(added_tokens_txt, encoding="utf-8") as f:
                    _added_tokens = [line.rstrip() for line in f]
                self.tokenizer.tokenizer.add_tokens(_added_tokens)
        else:
            raise ValueError(f"tokenizer unsupported: {model_type}")

        # Work on a private copy from here on.
        # NOTE(review): presumably this keeps the special tokens added below
        # from leaking into a tokenizer instance shared elsewhere -- confirm.
        self.tokenizer = copy.deepcopy(self.tokenizer)

        # Whisper uses discrete tokens (20ms) to encode timestamp:
        # "<|0.00|>", "<|0.02|>", ..., "<|30.00|>".
        timestamps = [f"<|{i * 0.02:.2f}|>" for i in range(0, 1501)]
        sc = [speaker_change_symbol] if sot else []
        special_tokens = (
            self.tokenizer.tokenizer.additional_special_tokens + timestamps + sc
        )
        self.tokenizer.tokenizer.add_special_tokens(
            dict(additional_special_tokens=special_tokens)
        )

    def __repr__(self):
        """Return a short description with the model type and language name."""
        return (
            f"{self.__class__.__name__}(model_type={self.model}, "
            f"language={self.language})"
        )

    def text2tokens(self, line: str) -> List[str]:
        """Split ``line`` into token strings without adding special tokens."""
        return self.tokenizer.tokenizer.tokenize(line, add_special_tokens=False)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join token strings back into a single text string."""
        return self.tokenizer.tokenizer.convert_tokens_to_string(tokens)