espnet2.torch_utils package


espnet2.torch_utils.load_pretrained_model.load_pretrained_model(init_param: str, model: torch.nn.modules.module.Module, map_location: str = 'cpu')

Load a model state and set it to the model.


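Parameters:

  • init_param – colon-separated string of the form <file_path>:<src_key>:<dst_key>:<exclude_keys>. Trailing fields can be omitted and inner fields left empty, as in the examples below.

  • model – the model into which the loaded state is set.

  • map_location – device to which the stored tensors are mapped when the file is loaded (default: 'cpu').

Examples:
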
>>> load_pretrained_model("somewhere/model.pth", model)
>>> load_pretrained_model("somewhere/model.pth:decoder:decoder", model)
>>> load_pretrained_model("somewhere/model.pth:decoder:decoder:", model)
>>> load_pretrained_model(
...     "somewhere/model.pth:decoder:decoder:decoder.embed", model
... )
>>> load_pretrained_model("somewhere/decoder.pth::decoder", model)
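The init_param string can be read as up to four colon-separated fields. The following is a minimal, stdlib-only sketch of how such a string might be parsed and used to select part of a flat state dict; it is not the library's implementation. The helper names parse_init_param and remap_state are hypothetical, a plain dict stands in for a PyTorch state_dict, and treating <exclude_keys> as comma-separated key prefixes is an assumption.

```python
from typing import Dict, Optional, Tuple


def parse_init_param(
    init_param: str,
) -> Tuple[str, Optional[str], Optional[str], Optional[str]]:
    """Hypothetical helper: split <file_path>:<src_key>:<dst_key>:<exclude_keys>.

    Omitted trailing fields and empty fields are both returned as None.
    """
    fields = init_param.split(":", 3)
    fields += [""] * (4 - len(fields))  # pad omitted trailing fields
    path, src_key, dst_key, excludes = fields
    return path, src_key or None, dst_key or None, excludes or None


def remap_state(
    state: Dict[str, object], src_key: Optional[str], excludes: Optional[str]
) -> Dict[str, object]:
    """Hypothetical helper: drop excluded prefixes, then keep only keys under src_key."""
    if excludes is not None:
        prefixes = tuple(excludes.split(","))  # assumption: comma-separated prefixes
        state = {k: v for k, v in state.items() if not k.startswith(prefixes)}
    if src_key is not None:
        prefix = src_key + "."
        # Strip "<src_key>." so the keys line up with the target submodule.
        state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    return state


state = {"encoder.w": 1, "decoder.embed.w": 2, "decoder.out.w": 3}
path, src, dst, exc = parse_init_param("somewhere/model.pth:decoder:decoder:decoder.embed")
subset = remap_state(state, src, exc)  # only decoder keys, minus decoder.embed.*
```

In this sketch the dst_key field is only parsed; it would presumably name the submodule of model that receives the remapped state.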


espnet2.torch_utils.device_funcs.force_gatherable(data, device)

Recursively convert an object into a form that torch.nn.DataParallel can gather.

Unlike to_device(), this function also converts any float or int value it finds into a torch.Tensor.

DataParallel places the following restrictions on the returned value. The object must be either:

  • a torch.cuda.Tensor with one or more dimensions (a 0-dimensional tensor triggers a warning), or

  • a list, tuple, or dict.
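The conversion described above can be sketched as a short recursive function. This is an illustrative sketch, not the library's code: the name force_gatherable_sketch is hypothetical, and the dtypes chosen for Python scalars (float32 for float, int64 for int) are assumptions. It shows the two ideas in the description: scalars become 1-element tensors, and 0-dimensional tensors are promoted to shape (1,).

```python
import torch


def force_gatherable_sketch(data, device):
    """Illustrative sketch: recursively convert `data` into DataParallel-gatherable form."""
    if isinstance(data, dict):
        return {k: force_gatherable_sketch(v, device) for k, v in data.items()}
    if type(data) in (list, tuple):  # plain containers only; namedtuples pass through
        return type(data)(force_gatherable_sketch(v, device) for v in data)
    if isinstance(data, torch.Tensor):
        if data.dim() == 0:
            data = data.unsqueeze(0)  # promote 0-dim tensors to shape (1,)
        return data.to(device)
    if isinstance(data, float):
        return torch.tensor([data], dtype=torch.float32, device=device)
    if isinstance(data, int):  # bool is a subclass of int, so it lands here too
        return torch.tensor([data], dtype=torch.long, device=device)
    return data  # None and unsupported types are returned unchanged


stats = force_gatherable_sketch(
    {"loss": 1.5, "count": 3, "acc": torch.tensor(0.5)}, "cpu"
)
```

Passing "cpu" keeps the sketch runnable without a GPU; inside a real DataParallel replica the device would be that replica's CUDA device.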

espnet2.torch_utils.device_funcs.to_device(data, device=None, dtype=None, non_blocking=False, copy=False)

Recursively change the device (and optionally the dtype) of an object and of any tensors nested inside it.


espnet2.torch_utils.add_gradient_noise.add_gradient_noise(model: torch.nn.modules.module.Module, iteration: int, duration: float = 100, eta: float = 1.0, scale_factor: float = 0.55)

Adds noise from a standard normal distribution to the gradients.

The standard deviation (sigma) of the noise is controlled by the three hyper-parameters duration, eta, and scale_factor below; sigma decays toward zero (no noise) as the number of iterations grows.

Parameters:

  • model – Model whose gradients receive the noise.

  • iteration – Current iteration count.

  • duration – {100, 1000}: Number of iterations between successive changes of sigma.

  • eta – {0.01, 0.3, 1.0}: The magnitude of sigma (its value at the first iteration).

  • scale_factor – {0.55}: Exponent that controls how quickly sigma decays.
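The schedule can be sketched in plain Python. This sketch assumes sigma = eta / (iteration // duration + 1) ** scale_factor, which is consistent with the description above (sigma equals eta at the start, stays constant for duration iterations, then shrinks toward zero). A list of floats is a hypothetical stand-in for a model's gradient tensors, and noise_sigma and add_noise are illustrative names, not ESPnet functions.

```python
import random
from typing import List


def noise_sigma(iteration: int, duration: float = 100, eta: float = 1.0,
                scale_factor: float = 0.55) -> float:
    """Assumed schedule: sigma stays at one value for `duration` iterations,
    then decays as eta / k ** scale_factor for block index k = 1, 2, 3, ..."""
    k = iteration // duration + 1
    return eta / k ** scale_factor


def add_noise(grads: List[float], iteration: int, rng: random.Random,
              **schedule) -> List[float]:
    """Return a copy of `grads` with Gaussian N(0, sigma^2) noise added to each entry."""
    sigma = noise_sigma(iteration, **schedule)
    # sigma * N(0, 1) has standard deviation sigma
    return [g + sigma * rng.gauss(0.0, 1.0) for g in grads]


rng = random.Random(0)
for it in (0, 100, 1000, 10000):
    print(it, round(noise_sigma(it), 4))  # sigma shrinks as iterations grow
noisy = add_noise([0.1, -0.2], iteration=0, rng=rng)
```

A PyTorch version of the same idea would add sigma * torch.randn_like(p.grad) to the gradient of each parameter p that has one.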


espnet2.torch_utils.set_all_random_seed.set_all_random_seed(seed: int)


espnet2.torch_utils.pytorch_version.pytorch_cudnn_version() → str


espnet2.torch_utils.recursive_op.recursive_average(obj, weight: torch.Tensor, distributed: bool = False)
espnet2.torch_utils.recursive_op.recursive_divide(a, b: torch.Tensor)
espnet2.torch_utils.recursive_op.recursive_sum(obj, weight: torch.Tensor, distributed: bool = False)


espnet2.torch_utils.initialize.initialize(model: torch.nn.modules.module.Module, init: str)


espnet2.torch_utils.model_summary.get_human_readable_count(number: int) → str

Return a human-readable abbreviation of a count.

Originated from:

Abbreviates an integer number with K, M, B, T for thousands, millions, billions and trillions, respectively.

Examples:

>>> get_human_readable_count(123)
'123  '
>>> get_human_readable_count(1234)  # (one thousand)
'1 K'
>>> get_human_readable_count(2e6)   # (two million)
'2 M'
>>> get_human_readable_count(3e9)   # (three billion)
'3 B'
>>> get_human_readable_count(4e12)  # (four trillion)
'4 T'
>>> get_human_readable_count(5e15)  # (more than trillion)
'5,000 T'

Parameters:

  • number – a positive integer number.

Returns:

  A string formatted according to the pattern described above.
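The abbreviation rule can be reproduced with the standard library alone. The following is a hypothetical re-implementation (the name human_readable_count is illustrative, not part of ESPnet) that uses integer arithmetic: it counts decimal digits, groups them in threes, caps the group at trillions, and floor-divides by the matching power of 1,000. It reproduces the documented examples, including the two trailing spaces for numbers below 1,000.

```python
from typing import Union


def human_readable_count(number: Union[int, float]) -> str:
    """Abbreviate a non-negative count with K, M, B, T suffixes (sketch)."""
    labels = [" ", "K", "M", "B", "T"]  # index = number of 3-digit groups dropped
    number = int(number)  # also accepts floats such as 2e6, as in the examples
    if number < 0:
        raise ValueError("number must be non-negative")
    num_digits = len(str(number))
    # Groups of three digits, capped so that values beyond trillions stay in T.
    num_groups = min((num_digits + 2) // 3, len(labels))
    scaled = number // 10 ** (3 * (num_groups - 1))
    return f"{scaled:,d} {labels[num_groups - 1]}"


print(repr(human_readable_count(5e15)))
```

Working on integers instead of logarithms avoids floating-point rounding at exact powers of ten, at the cost of only accepting values that convert cleanly with int().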

espnet2.torch_utils.model_summary.model_summary(model: torch.nn.modules.module.Module) → str
espnet2.torch_utils.model_summary.to_bytes(dtype) → int



class espnet2.torch_utils.forward_adaptor.ForwardAdaptor(module: torch.nn.modules.module.Module, name: str)

Bases: torch.nn.modules.module.Module

Wrapper module that makes a specified method parallelizable.

torch.nn.DataParallel parallelizes only forward(). A method with any other name cannot be parallelized directly; instead, wrap the module, as this class does, so that forward() dispatches to that method.

Examples:

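>>> import torch
>>> from espnet2.torch_utils.forward_adaptor import ForwardAdaptor
>>> class A(torch.nn.Module):
...     def foo(self, x):
...         return x * 2  # any computation
>>> model = A()
>>> model = ForwardAdaptor(model, "foo")
>>> model = torch.nn.DataParallel(model, device_ids=[0, 1])  # needs 2 GPUs
>>> x = torch.randn(2, 10)
>>> y = model(x)  # runs A.foo on each replica and gathers the results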
forward(*args, **kwargs)

Call the wrapped module's method selected by name with *args and **kwargs and return its result.

As with any torch.nn.Module, call the module instance (for example, model(x)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.
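The wrapping idea fits in a few lines. This is a hedged illustration, not the ESPnet source: the class name ForwardAdaptorSketch and the ValueError raised for a missing method are assumptions. The essential points are that the wrapped module is stored as a submodule, so each DataParallel replica holds its own copy, and that forward() simply dispatches to the named method.

```python
import torch


class ForwardAdaptorSketch(torch.nn.Module):
    """Illustrative re-implementation: route forward() to another method."""

    def __init__(self, module: torch.nn.Module, name: str):
        super().__init__()
        if not hasattr(module, name):
            raise ValueError(f"{type(module).__name__} has no method {name!r}")
        self.module = module  # registered as a submodule, so replicas copy it
        self.name = name

    def forward(self, *args, **kwargs):
        # Look up the method on the (possibly replicated) wrapped module and call it.
        return getattr(self.module, self.name)(*args, **kwargs)


class A(torch.nn.Module):
    def foo(self, x):
        return x * 2


wrapped = ForwardAdaptorSketch(A(), "foo")
out = wrapped(torch.ones(3))  # dispatches to A.foo on CPU, no GPU needed
```

Wrapping such an instance in torch.nn.DataParallel then parallelizes foo() in the same way as the ForwardAdaptor example.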