"""Torch utility module."""
import torch
if torch.distributed.is_available():
from torch.distributed import ReduceOp
def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False):
    """Reduce every tensor in a nested structure to its weighted sum.

    Containers (tuple, list, namedtuple, dict) are traversed recursively and
    rebuilt with the same container type; ``None`` leaves are passed through.

    Args:
        obj: A tensor, ``None``, or an arbitrarily nested tuple / list /
            namedtuple / dict of those. Every tensor must have exactly the
            same size as ``weight``.
        weight: 1-D tensor of weights. It is cast to each tensor's dtype
            before the multiplication.
        distributed: If True, the local weighted sum of every tensor is
            additionally combined across all ranks of the default
            ``torch.distributed`` process group.

    Returns:
        A structure mirroring ``obj`` in which every tensor is replaced by
        a 0-dim tensor holding its weighted sum.

    Raises:
        AssertionError: If ``weight`` is not 1-D, or a tensor's size differs
            from ``weight``'s size.
        ValueError: If a leaf is neither a tensor nor ``None``.
    """
    assert weight.dim() == 1, weight.size()

    if isinstance(obj, (tuple, list)):
        values = [recursive_sum(v, weight, distributed) for v in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable.
            return type(obj)(*values)
        return type(obj)(values)

    if isinstance(obj, dict):
        return {k: recursive_sum(v, weight, distributed) for k, v in obj.items()}

    if isinstance(obj, torch.Tensor):
        assert obj.size() == weight.size(), (obj.size(), weight.size())
        total = (obj * weight.type(obj.dtype)).sum()
        if distributed:
            world_size = torch.distributed.get_world_size()
            gathered = [torch.empty_like(total) for _ in range(world_size)]
            torch.distributed.all_gather(gathered, total)
            # Stack once and reuse it for both the NaN test and the reduction.
            stacked = torch.stack(gathered)
            if torch.isnan(stacked).all():
                # Every rank produced NaN: the plain sum keeps the result NaN.
                total = torch.sum(stacked)
            else:
                # NOTE(wangyou): not using torch.nansum here to compensate for
                # the reduced samples.
                # This is important so that the condition-specific loss values
                # reported in Reporter will be consistent with the general
                # loss value.
                total = torch.nanmean(stacked) * len(gathered)
        return total

    if obj is None:
        return None

    raise ValueError(type(obj))
def recursive_divide(a, b: torch.Tensor):
    """Divide every tensor in a nested structure by ``b``.

    Containers (tuple, list, namedtuple, dict) are traversed recursively and
    rebuilt with the same container type; ``None`` leaves are passed through.

    Args:
        a: A tensor, ``None``, or an arbitrarily nested tuple / list /
            namedtuple / dict of those. Every tensor must have exactly the
            same size as ``b``.
        b: The divisor. It is cast to each tensor's dtype before dividing.

    Returns:
        A structure mirroring ``a`` in which every tensor ``t`` is replaced
        by ``t / b``.

    Raises:
        AssertionError: If a tensor's size differs from ``b``'s size.
        ValueError: If a leaf is neither a tensor nor ``None``.
    """
    if isinstance(a, (tuple, list)):
        values = [recursive_divide(v, b) for v in a]
        if isinstance(a, tuple) and hasattr(a, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable.
            return type(a)(*values)
        return type(a)(values)

    if isinstance(a, dict):
        return {k: recursive_divide(v, b) for k, v in a.items()}

    if isinstance(a, torch.Tensor):
        assert a.size() == b.size(), (a.size(), b.size())
        return a / b.type(a.dtype)

    if a is None:
        return None

    raise ValueError(type(a))
def recursive_average(obj, weight: torch.Tensor, distributed: bool = False):
    """Compute the weighted average of every tensor in a nested structure.

    Each tensor is first reduced to its weighted sum with ``recursive_sum``
    and then divided by the total weight with ``recursive_divide``.

    Args:
        obj: Nested structure of tensors, forwarded to ``recursive_sum``.
        weight: 1-D tensor of weights.
        distributed: If True, both the weighted sums and the total weight
            are combined across the ``torch.distributed`` process group.

    Returns:
        A pair ``(averaged, total_weight)``: the structure of weighted
        averages and the 0-dim tensor holding the summed weight (summed
        over all ranks when ``distributed`` is True).
    """
    summed = recursive_sum(obj, weight, distributed)

    total_weight = weight.sum()
    if distributed:
        # all_reduce works in place on the freshly created sum tensor, so
        # the caller's ``weight`` tensor is not modified.
        torch.distributed.all_reduce(total_weight, op=ReduceOp.SUM)

    # Dividing the weighted sums by the total weight is equivalent to
    # normalising the weights to sum to 1.
    averaged = recursive_divide(summed, total_weight)
    return averaged, total_weight