# Source code for espnet2.main_funcs.pack_funcs

import os
import sys
import tarfile
import zipfile
from datetime import datetime
from io import BytesIO, TextIOWrapper
from pathlib import Path
from typing import Dict, Iterable, Optional, Union

import yaml


class Archiver:
    """Uniform read/write interface over tar and zip archives.

    The archive type is detected from the file name: ``.tar``,
    ``.tgz``/``.tar.gz``, ``.tbz2``/``.tar.bz2``, ``.txz``/``.tar.xz``
    (handled by :mod:`tarfile`) and ``.zip`` (handled by :mod:`zipfile`).
    Instances are context managers that close the underlying archive on exit.

    Args:
        file: Path of the archive.
        mode: Open mode passed to ``tarfile.open`` / ``zipfile.ZipFile``.
            For compressed tar formats ``"w"`` is turned into the matching
            ``"w:gz"``, ``"w:bz2"`` or ``"w:xz"``.

    Raises:
        ValueError: If the archive format cannot be detected from ``file``.
    """

    def __init__(self, file, mode="r"):
        path = Path(file)
        # Compare only the trailing two suffixes so that names containing
        # extra dots (e.g. "model.v1.tar.gz") are still recognised.
        last_two = path.suffixes[-2:]
        if path.suffix == ".tar":
            self.type = "tar"
        elif path.suffix == ".tgz" or last_two == [".tar", ".gz"]:
            self.type = "tar"
            if mode == "w":
                mode = "w:gz"
        elif path.suffix == ".tbz2" or last_two == [".tar", ".bz2"]:
            self.type = "tar"
            if mode == "w":
                mode = "w:bz2"
        elif path.suffix == ".txz" or last_two == [".tar", ".xz"]:
            self.type = "tar"
            if mode == "w":
                mode = "w:xz"
        elif path.suffix == ".zip":
            self.type = "zip"
        else:
            raise ValueError(f"Cannot detect archive format: type={file}")

        if self.type == "tar":
            self.fopen = tarfile.open(file, mode=mode)
        elif self.type == "zip":
            self.fopen = zipfile.ZipFile(file, mode=mode)
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.fopen.close()

    def close(self):
        """Close the underlying tar/zip file."""
        self.fopen.close()

    def __iter__(self):
        """Iterate over members (``TarInfo`` for tar, ``ZipInfo`` for zip)."""
        if self.type == "tar":
            return iter(self.fopen)
        elif self.type == "zip":
            return iter(self.fopen.infolist())
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def add(self, filename, arcname=None, recursive: bool = True):
        """Add a file, or every regular file below a directory, to the archive.

        Args:
            filename: File or directory on disk.
            arcname: Name to store it under. For a directory, each file keeps
                its path relative to ``filename`` below ``arcname``. ``None``
                stores files under their on-disk path.
            recursive: When ``filename`` is a directory, walk it and add each
                regular file individually.
        """
        if arcname is not None:
            print(f"adding: {arcname}")
        else:
            print(f"adding: {filename}")

        if recursive and Path(filename).is_dir():
            for f in Path(filename).glob("**/*"):
                if f.is_dir():
                    continue
                if arcname is not None:
                    # ``f`` already starts with ``filename``; re-root only the
                    # part below it so the prefix is not duplicated.
                    _arcname = str(Path(arcname) / f.relative_to(filename))
                else:
                    _arcname = None
                self.add(f, _arcname)
            return

        if self.type == "tar":
            return self.fopen.add(filename, arcname)
        elif self.type == "zip":
            return self.fopen.write(filename, arcname)
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def addfile(self, info, fileobj):
        """Add a member described by ``info`` whose content is read from ``fileobj``."""
        print(f"adding: {self.get_name_from_info(info)}")
        if self.type == "tar":
            return self.fopen.addfile(info, fileobj)
        elif self.type == "zip":
            return self.fopen.writestr(info, fileobj.read())
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def generate_info(self, name, size) -> Union[tarfile.TarInfo, zipfile.ZipInfo]:
        """Generate a TarInfo or ZipInfo for ``name`` stamped with the current time.

        On POSIX the tar entry also records the current uid/gid; the tar
        permission mode is left at its default.
        """
        if self.type == "tar":
            tarinfo = tarfile.TarInfo(str(name))
            if os.name == "posix":
                tarinfo.gid = os.getgid()
                tarinfo.uid = os.getuid()
            tarinfo.mtime = datetime.now().timestamp()
            tarinfo.size = size
            return tarinfo
        elif self.type == "zip":
            zipinfo = zipfile.ZipInfo(str(name), datetime.now().timetuple()[:6])
            zipinfo.file_size = size
            return zipinfo
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def get_name_from_info(self, info):
        """Return the member name stored in a TarInfo/ZipInfo."""
        if self.type == "tar":
            assert isinstance(info, tarfile.TarInfo), type(info)
            return info.name
        elif self.type == "zip":
            assert isinstance(info, zipfile.ZipInfo), type(info)
            return info.filename
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def extract(self, info, path=None):
        """Extract one member into ``path`` (current directory when ``None``)."""
        if self.type == "tar":
            return self.fopen.extract(info, path)
        elif self.type == "zip":
            return self.fopen.extract(info, path)
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def extractfile(self, info, mode="r"):
        """Open a member for reading.

        For tar, ``mode="r"`` returns a UTF-8 text stream and any other mode
        returns the binary stream. ``None`` is returned for non-regular
        members, mirroring ``TarFile.extractfile``. For zip, a binary stream
        is returned, and ``"rb"`` is accepted as an alias of ``"r"``.
        """
        if self.type == "tar":
            f = self.fopen.extractfile(info)
            if f is None or mode != "r":
                return f
            # Explicit encoding: pack() writes meta.yaml as UTF-8, so do not
            # depend on the locale's default encoding.
            return TextIOWrapper(f, encoding="utf-8")
        elif self.type == "zip":
            if mode == "rb":
                mode = "r"
            return self.fopen.open(info, mode)
        else:
            raise ValueError(f"Not supported: type={self.type}")
def find_path_and_change_it_recursive(value, src: str, tgt: str):
    """Replace every string equal to path ``src`` with ``tgt`` inside ``value``.

    Dicts are rebuilt with the same keys. Lists and tuples are rebuilt as
    lists. Strings are compared as ``Path`` objects, so equivalent spellings
    such as ``"./a/b"`` and ``"a/b"`` match. Any other value is returned
    unchanged.
    """

    def _walk(item):
        return find_path_and_change_it_recursive(item, src, tgt)

    if isinstance(value, dict):
        return {key: _walk(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return list(map(_walk, value))
    matches = isinstance(value, str) and Path(value) == Path(src)
    return tgt if matches else value
def get_dict_from_cache(meta: Union[Path, str]) -> Optional[Dict[str, str]]:
    """Rebuild the result of a previous ``unpack`` from its extracted meta.yaml.

    The root directory is taken to be the grandparent of ``meta``
    (``meta.parent.parent``). Every path listed under ``yaml_files`` and
    ``files`` is resolved against that root.

    Args:
        meta: Path of the cached ``meta.yaml``.

    Returns:
        A mapping from each key to its path (as a string) under the root. The
        function returns ``None`` on a cache miss, i.e. when ``meta`` or any
        listed file does not exist.
    """
    meta_path = Path(meta)
    if not meta_path.exists():
        return None
    # NOTE(review): pack() writes meta.yaml at the archive top level, and
    # unpack() passes ``outpath / <member name>``. For such archives this root
    # is the parent of outpath — confirm this layout is intended.
    root = meta_path.parent.parent

    with meta_path.open("r", encoding="utf-8") as fh:
        content = yaml.safe_load(fh)
    assert isinstance(content, dict), type(content)

    yaml_entries = content["yaml_files"]
    file_entries = content["files"]
    assert isinstance(yaml_entries, dict), type(yaml_entries)
    assert isinstance(file_entries, dict), type(file_entries)

    cached = {}
    for key, rel in [*yaml_entries.items(), *file_entries.items()]:
        candidate = root / rel
        if not candidate.exists():
            return None
        cached[key] = str(candidate)
    return cached
def unpack(
    input_archive: Union[Path, str],
    outpath: Union[Path, str],
    use_cache: bool = True,
) -> Dict[str, str]:
    """Scan all files in the archive file and return as a dict of files.

    The archive must contain a ``meta.yaml`` listing ``yaml_files`` and
    ``files``. Every member is extracted under ``outpath``. For YAML files,
    path strings that name archive members are rewritten to point into
    ``outpath`` before the file is written out.

    Args:
        input_archive: Archive to unpack (any format ``Archiver`` accepts).
        outpath: Destination directory.
        use_cache: If a previously extracted ``meta.yaml`` exists under
            ``outpath``, try ``get_dict_from_cache`` first. When it returns a
            result, extraction is skipped.

    Returns:
        A mapping from each key in meta.yaml to ``str(outpath / <path>)``.

    Raises:
        RuntimeError: If the archive has no ``meta.yaml``.

    Examples:
        tarfile:
           model.pth
           some1.file
           some2.file

        >>> unpack("tarfile", "out")
        {'asr_model_file': 'out/model.pth'}
    """
    input_archive = Path(input_archive)
    outpath = Path(outpath)

    with Archiver(input_archive) as archive:
        # Locate meta.yaml, which describes the packed files.
        for info in archive:
            meta_name = archive.get_name_from_info(info)
            if Path(meta_name).name != "meta.yaml":
                continue
            cached_meta = outpath / meta_name
            if use_cache and cached_meta.exists():
                retval = get_dict_from_cache(cached_meta)
                if retval is not None:
                    return retval
            with archive.extractfile(info) as f:
                meta = yaml.safe_load(f)
            assert isinstance(meta, dict), type(meta)
            yaml_files = meta["yaml_files"]
            files = meta["files"]
            assert isinstance(yaml_files, dict), type(yaml_files)
            assert isinstance(files, dict), type(files)
            break
        else:
            raise RuntimeError("Format error: not found meta.yaml")

        # Loop invariants, computed once instead of per member.
        yaml_targets = set(yaml_files.values())
        member_names = [archive.get_name_from_info(i) for i in archive]

        for info in archive:
            fname = archive.get_name_from_info(info)
            outname = outpath / fname
            outname.parent.mkdir(parents=True, exist_ok=True)

            if fname in yaml_targets:
                with archive.extractfile(info) as f:
                    d = yaml.safe_load(f)
                # Point every reference to an archive member into outpath.
                for name in member_names:
                    d = find_path_and_change_it_recursive(
                        d, name, str(outpath / name)
                    )
                with outname.open("w", encoding="utf-8") as f:
                    yaml.safe_dump(d, f)
            else:
                # NOTE(security): tar members are extracted without path
                # filtering, so names with ".." or absolute paths could land
                # outside outpath. Only unpack archives from trusted sources.
                archive.extract(info, path=outpath)

    retval = {}
    for key, value in list(yaml_files.items()) + list(files.items()):
        retval[key] = str(outpath / value)
    return retval
def _to_relative_or_resolve(f): # Resolve to avoid symbolic link p = Path(f).resolve() try: # Change to relative if it can p = p.relative_to(Path(".").resolve()) except ValueError: pass return str(p)
def pack(
    files: Dict[str, Union[str, Path]],
    yaml_files: Dict[str, Union[str, Path]],
    outpath: Union[str, Path],
    option: Iterable[Union[str, Path]] = (),
):
    """Bundle files, YAML configs and extra paths into a single archive.

    Paths are normalized by ``_to_relative_or_resolve``, which makes them
    relative to the current directory when possible. They are added to the
    archive under those names. A ``meta.yaml`` is written at the top level of
    the archive. It records the ``files`` and ``yaml_files`` mappings, a
    timestamp, the Python version, and the torch/espnet versions when those
    packages can be imported.

    Args:
        files: Key -> path of files to pack.
        yaml_files: Key -> path of YAML files to pack. ``unpack`` rewrites
            paths inside these files.
        outpath: Archive to create. The format comes from the extension (see
            ``Archiver``). Parent directories are created.
        option: Extra files or directories to include. Any iterable, including
            a one-shot generator, is accepted.

    Raises:
        FileNotFoundError: If any given path does not exist.
    """
    # Materialize once: ``option`` is traversed several times below, and a
    # generator would be exhausted after the first pass.
    option = list(option)

    for v in list(files.values()) + list(yaml_files.values()) + option:
        if not Path(v).exists():
            raise FileNotFoundError(f"No such file or directory: {v}")

    files = {k: _to_relative_or_resolve(v) for k, v in files.items()}
    yaml_files = {k: _to_relative_or_resolve(v) for k, v in yaml_files.items()}
    option = [_to_relative_or_resolve(v) for v in option]

    meta_objs = dict(
        files=files,
        yaml_files=yaml_files,
        timestamp=datetime.now().timestamp(),
        python=sys.version,
    )

    # Version info is best-effort: record it only when the package is importable.
    try:
        import torch

        meta_objs.update(torch=str(torch.__version__))
    except ImportError:
        pass
    try:
        import espnet

        meta_objs.update(espnet=espnet.__version__)
    except ImportError:
        pass

    Path(outpath).parent.mkdir(parents=True, exist_ok=True)
    with Archiver(outpath, mode="w") as archive:
        # Write meta.yaml at the top level of the archive.
        fileobj = BytesIO(yaml.safe_dump(meta_objs).encode("utf-8"))
        info = archive.generate_info("meta.yaml", fileobj.getbuffer().nbytes)
        archive.addfile(info, fileobj=fileobj)

        for f in list(yaml_files.values()) + list(files.values()) + option:
            archive.add(f)

    print(f"Generate: {outpath}")