import math
import random
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torchaudio
def weighted_sample_without_replacement(population, weights, k, rng=random):
    """Draw ``k`` distinct items from ``population`` with probability ~ ``weights``.

    Uses the Efraimidis-Spirakis scheme: each item gets the key
    ``u ** (1 / w)`` with ``u ~ U[0, 1)``, and the ``k`` items with the largest
    keys are returned.

    Args:
        population (Sequence): items to sample from.
        weights (Sequence[float]): one sampling weight per item. Items with a
            non-positive weight get key 0, i.e. they are only picked when ``k``
            forces it.
        k (int): number of items to draw (0 <= k <= len(population)).
        rng: object with a ``random()`` method (defaults to the ``random`` module).

    Returns:
        list: the sampled items (``[]`` when ``k == 0``).

    Raises:
        ValueError: if ``k`` is negative or larger than ``len(population)``.
    """
    if k < 0:
        raise ValueError("Sample size must be non-negative")
    if k > len(population):
        raise ValueError(
            "Cannot take a larger sample than population when without replacement"
        )
    if k == 0:
        # `order[-0:]` would return the whole population, not an empty sample.
        return []
    keys = []
    for w in weights:
        # Always draw, so the RNG stream does not depend on the weight values.
        u = rng.random()
        # Non-positive weights would divide by zero (or favour the item);
        # give them the smallest possible key instead.
        keys.append(u ** (1 / w) if w > 0 else 0.0)
    order = sorted(range(len(population)), key=lambda i: keys[i])
    return [population[i] for i in order[len(order) - k :]]
class DataAugmentation:
    """A series of data augmentation effects that can be applied to a given waveform.

    Note: Currently we only support single-channel waveforms.

    Args:
        effects (list): a list of effects to be applied to the waveform.
            Example:
                [
                    [0.1, "lowpass", {"cutoff_freq": 1000, "Q": 0.707}],
                    [0.1, "highpass", {"cutoff_freq": 3000, "Q": 0.707}],
                    [0.1, "equalization", {"center_freq": 1000, "gain": 0, "Q": 0.707}],
                    [
                        0.1,
                        [
                            [0.3, "speed_perturb", {"factor": 0.9}],
                            [0.3, "speed_perturb", {"factor": 1.1}],
                        ]
                    ],
                ]
            Description:
                - The above list defines a series of data augmentation effects
                  that will be randomly sampled to apply to a given waveform.
                - The data structure of each element can be either
                  type1=Tuple[float, str, Dict] or type2=Tuple[float, List[type1]]
                  (lists and tuples are both accepted).
                - In type1, the three values are the weight of sampling this
                  effect, the name (key) of the effect, and the keyword arguments
                  for the effect.
                - In type2, the first value is the weight of sampling this group.
                  The second value is a list of type1 elements which are
                  similarly defined as above.
                - Note that the effects defined in each type2 element are mutually
                  exclusive (i.e., only one of them can be applied each time).
                  This can be useful when you want to avoid applying some
                  specific effects at the same time.
        apply_n (tuple): inclusive range ``(min, max)`` of the number of effects
            applied to the waveform per call. ``min`` may be 0 (no effect).
    """

    def __init__(
        self,
        effects: List[
            Union[
                Tuple[float, List[Tuple[float, str, Dict]]],
                Tuple[float, str, Dict],
            ]
        ],
        apply_n: Tuple[int, int] = (1, 1),
    ):
        # type1 entries are stored as (name, kwargs); type2 entries keep their
        # list of (weight, name, kwargs) candidates. A type1 entry is recognised
        # by the effect name (a str) in position 1, so this works for both list-
        # and tuple-based configs.
        self.effects = tuple(
            tup[1:] if isinstance(tup[1], str) else tup[1] for tup in effects
        )
        self.effect_probs = tuple(tup[0] for tup in effects)
        assert apply_n[0] <= apply_n[1], apply_n
        assert apply_n[1] > 0, apply_n
        self.apply_n = tuple(apply_n)

    def __call__(self, waveform, sample_rate):
        """Apply a random selection of the configured effects to ``waveform``.

        Args:
            waveform (np.ndarray or torch.Tensor): 1-D audio signal (time,)
            sample_rate (int): sampling rate in Hz
        Returns:
            np.ndarray: the augmented waveform (moved to CPU)
        """
        if isinstance(waveform, np.ndarray):
            waveform = torch.from_numpy(waveform)
        assert waveform.ndim == 1, waveform.shape
        # NOTE: the number of effects is drawn from NumPy's global RNG while the
        # effects themselves are drawn from Python's `random` module; both must
        # be seeded for reproducible augmentation.
        low, high = self.apply_n
        apply_n = np.random.randint(low, high + 1) if high > low else low
        if apply_n > 0:
            chosen = weighted_sample_without_replacement(
                self.effects, weights=self.effect_probs, k=apply_n
            )
        else:
            # Explicitly apply nothing; do not rely on sampling with k == 0.
            chosen = []
        for effect in chosen:
            if isinstance(effect[0], str):
                # type1: (name, kwargs)
                eff, eff_args = effect
            else:
                # type2: mutually exclusive group, pick exactly one candidate
                probs = [tup[0] for tup in effect]
                _, eff, eff_args = weighted_sample_without_replacement(
                    effect, weights=probs, k=1
                )[0]
            waveform = self._apply_effect(waveform, sample_rate, eff, eff_args)
        return waveform.cpu().numpy()

    def _apply_effect(self, waveform, sample_rate, eff, eff_args):
        """Look up effect ``eff`` and call it with ``eff_args`` as keyword arguments."""
        # `sample_rate` is always passed explicitly, so drop any copy from the
        # config -- without mutating the (shared) config dict itself.
        kwargs = {key: val for key, val in eff_args.items() if key != "sample_rate"}
        return effects_dict[eff](waveform, sample_rate, **kwargs)
def lowpass_filtering(
    waveform, sample_rate: int, cutoff_freq: int = 1000, Q: float = 0.707
):
    """Attenuate content above ``cutoff_freq`` with a biquad lowpass filter.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        cutoff_freq (int): filter cutoff frequency in Hz
        Q (float or torch.Tensor): quality factor,
            see https://en.wikipedia.org/wiki/Q_factor
    Returns:
        torch.Tensor: filtered signal (..., time)
    """
    biquad = torchaudio.functional.lowpass_biquad
    return biquad(waveform, sample_rate, cutoff_freq, Q=Q)
def highpass_filtering(
    waveform, sample_rate: int, cutoff_freq: int = 3000, Q: float = 0.707
):
    """Attenuate content below ``cutoff_freq`` with a biquad highpass filter.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        cutoff_freq (int): filter cutoff frequency in Hz
        Q (float or torch.Tensor): quality factor,
            see https://en.wikipedia.org/wiki/Q_factor
    Returns:
        torch.Tensor: filtered signal (..., time)
    """
    biquad = torchaudio.functional.highpass_biquad
    return biquad(waveform, sample_rate, cutoff_freq, Q=Q)
def bandpass_filtering(
    waveform,
    sample_rate: int,
    center_freq: int = 3000,
    Q: float = 0.707,
    const_skirt_gain: bool = False,
):
    """Keep a band of frequencies around ``center_freq`` (biquad bandpass filter).

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        center_freq (int): center frequency of the pass band in Hz
        Q (float or torch.Tensor): quality factor,
            see https://en.wikipedia.org/wiki/Q_factor
        const_skirt_gain (bool): if True, use a constant skirt gain (peak gain = Q);
            if False, use a constant 0 dB peak gain.
    Returns:
        torch.Tensor: filtered signal (..., time)
    """
    biquad = torchaudio.functional.bandpass_biquad
    return biquad(
        waveform,
        sample_rate,
        center_freq,
        Q=Q,
        const_skirt_gain=const_skirt_gain,
    )
def bandreject_filtering(
    waveform, sample_rate: int, center_freq: int = 3000, Q: float = 0.707
):
    """Suppress a band of frequencies around ``center_freq`` (two-pole band-reject).

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        center_freq (int): center frequency of the rejected band in Hz
        Q (float or torch.Tensor): quality factor,
            see https://en.wikipedia.org/wiki/Q_factor
    Returns:
        torch.Tensor: filtered signal (..., time)
    """
    biquad = torchaudio.functional.bandreject_biquad
    return biquad(waveform, sample_rate, center_freq, Q=Q)
def contrast(waveform, sample_rate: int = 16000, enhancement_amount: float = 75.0):
    """Make the signal sound louder by applying the contrast effect.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz (unused; kept for a uniform signature)
        enhancement_amount (float): strength of the enhancement, in the range 0-100.
            Even 0 still yields a noticeable contrast enhancement.
    Returns:
        torch.Tensor: filtered signal (..., time)
    """
    apply_contrast = torchaudio.functional.contrast
    return apply_contrast(waveform, enhancement_amount)
def equalization_filtering(
    waveform,
    sample_rate: int,
    center_freq: int = 1000,
    gain: float = 0.0,
    Q: float = 0.707,
):
    """Boost or cut the band around ``center_freq`` with a biquad peaking equalizer.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        center_freq (int): center frequency of the equalized band in Hz
        gain (float or torch.Tensor): boost (> 0) or attenuation (< 0) in dB
        Q (float or torch.Tensor): quality factor,
            see https://en.wikipedia.org/wiki/Q_factor
    Returns:
        torch.Tensor: filtered signal (..., time)
    """
    equalize = torchaudio.functional.equalizer_biquad
    return equalize(waveform, sample_rate, center_freq, gain, Q=Q)
def pitch_shift(
    waveform,
    sample_rate: int,
    n_steps: int,
    bins_per_octave: int = 12,
    n_fft: float = 0.032,
    win_length: Optional[float] = None,
    hop_length: Optional[float] = 0.008,
    window: Optional[str] = "hann",
):
    """Shift the pitch of a waveform by `n_steps` steps.

    Note: this function is slow.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        n_steps (int): the (fractional) steps to shift the pitch
            -4 for shifting pitch down by 4/`bins_per_octave` octaves
            4 for shifting pitch up by 4/`bins_per_octave` octaves
        bins_per_octave (int): number of steps per octave
        n_fft (float): length of FFT (in seconds)
        win_length (float or None): window length (in seconds) used for the STFT.
            If None, it is treated as equal to n_fft.
        hop_length (float or None): hop size (in seconds) used for the STFT.
            If None, it is a quarter of n_fft.
        window (str or None): name of the torch window function (e.g. "hann"
            for ``torch.hann_window``) applied to each frame; None for no window
    Returns:
        torch.Tensor: pitch-shifted signal (..., time)
    """
    # Convert all STFT parameters from seconds to samples.
    n_fft = int(sample_rate * n_fft)
    if hop_length is None:
        hop_length = n_fft // 4
    else:
        hop_length = int(sample_rate * hop_length)
    if win_length is None:
        win_length = n_fft
    else:
        # Previously left in seconds, which produced a float window length.
        win_length = int(sample_rate * win_length)
    if window is not None:
        window_func = getattr(torch, f"{window}_window")
        window = window_func(win_length, dtype=waveform.dtype, device=waveform.device)
    return torchaudio.functional.pitch_shift(
        waveform,
        sample_rate,
        n_steps,
        bins_per_octave=bins_per_octave,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        window=window,
    )
def speed_perturb(waveform, sample_rate: int, factor: float):
    """Speed perturbation which also changes the pitch.

    Note: This function should be used with caution as it changes the signal duration.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        factor (float): speed factor (e.g., 0.9 for 90% speed); must be positive
    Returns:
        torch.Tensor: perturbed signal (..., time); about len / factor samples long
    Raises:
        ValueError: if ``factor * sample_rate`` does not round to a positive rate.
    """
    # Treating the signal as if it were recorded at `factor * sample_rate` and
    # resampling it to `sample_rate` lengthens (factor < 1) or shortens
    # (factor > 1) it. round() rather than int() avoids losing a whole Hz to
    # floating-point error, e.g. int(100 * 0.29) == 28.
    source_sample_rate = int(round(factor * sample_rate))
    target_sample_rate = int(sample_rate)
    if source_sample_rate <= 0 or target_sample_rate <= 0:
        raise ValueError(
            f"Invalid speed factor {factor} for sample rate {sample_rate}"
        )
    # Resample with the smallest equivalent integer ratio.
    gcd = math.gcd(source_sample_rate, target_sample_rate)
    return torchaudio.functional.resample(
        waveform, source_sample_rate // gcd, target_sample_rate // gcd
    )
def time_stretch(
    waveform,
    sample_rate: int,
    factor: float,
    n_fft: float = 0.032,
    win_length: Optional[float] = None,
    hop_length: Optional[float] = 0.008,
    window: Optional[str] = "hann",
):
    """Time scaling (speed up in time without modifying pitch) via phase vocoder.

    Note: This function should be used with caution as it changes the signal duration.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        factor (float): speed-up factor (e.g., 0.9 for 90% speed and 1.3 for 130% speed)
        n_fft (float): length of FFT (in seconds)
        win_length (float or None): window length (in seconds) used for the STFT.
            If None, it is treated as equal to n_fft.
        hop_length (float or None): hop size (in seconds) used for the STFT.
            If None, it is a quarter of n_fft.
        window (str or None): name of the torch window function (e.g. "hann"
            for ``torch.hann_window``) applied to each frame; None for no window
    Returns:
        torch.Tensor: perturbed signal (..., round(time / factor))
    """
    # Convert all STFT parameters from seconds to samples.
    n_fft = int(sample_rate * n_fft)
    if hop_length is None:
        hop_length = n_fft // 4
    else:
        hop_length = int(sample_rate * hop_length)
    if win_length is None:
        win_length = n_fft
    else:
        # Previously left in seconds, which produced a float window length.
        win_length = int(sample_rate * win_length)
    if window is not None:
        window_func = getattr(torch, f"{window}_window")
        window = window_func(win_length, dtype=waveform.dtype, device=waveform.device)
    spec = torch.stft(
        waveform,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    # Expected per-hop phase advance of each frequency bin, shaped (freq, 1).
    # Built on the spectrogram's device so GPU inputs work.
    phase_advance = torch.linspace(
        0,
        math.pi * hop_length,
        spec.size(-2),
        dtype=waveform.dtype,
        device=spec.device,
    )[..., None]
    stretched = torchaudio.functional.phase_vocoder(spec, factor, phase_advance)
    len_stretch = int(round(waveform.size(-1) / factor))
    return torch.istft(
        stretched,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        length=len_stretch,
    )
def codecs(
    waveform,
    sample_rate: int,
    format: str,
    compression: Optional[float] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
):
    """Apply the specified codecs to the input signal (not implemented yet).

    Warning:
        This function currently always raises ``NotImplementedError``.
        A future implementation can wrap ``torchaudio.functional.apply_codec``.
        For a 1-D waveform, pass ``waveform.unsqueeze(0)`` together with
        ``channels_first=True``. Combining ``unsqueeze(0)`` with
        ``channels_first=False`` would treat every sample as a separate channel.
    Note:
        1. Intended for the CPU backend only.
        2. The GSM codec can be used to emulate phone line channel effects.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        format (str): file format.
            Valid values are "wav", "mp3", "ogg", "vorbis", "amr-nb", "amb",
            "flac", "sph", "gsm", and "htk".
        compression (float or None, optional): used for formats other than WAV
            For more details see torchaudio.backend.sox_io_backend.save().
        encoding (str or None, optional): change the encoding for the supported formats
            Valid values are "PCM_S" (signed integer Linear PCM),
            "PCM_U" (unsigned integer Linear PCM), "PCM_F" (floating point PCM),
            "ULAW" (mu-law), and "ALAW" (a-law).
            For more details see torchaudio.backend.sox_io_backend.save().
        bits_per_sample (int or None, optional): change the bit depth
            for the supported formats
            For more details see torchaudio.backend.sox_io_backend.save().
    Returns:
        torch.Tensor: compressed signal (..., time) -- once implemented
    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("The 'codecs' augmentation is not implemented yet")
def preemphasis(waveform, sample_rate: int, coeff: float = 0.97):
    """Apply a first-order pre-emphasis along the last (time) axis.

    Computes ``y[i] = x[i] - coeff * x[i - 1]`` for ``i >= 1`` and keeps
    ``y[0] = x[0]``. The input tensor is not modified.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz (unused; kept for a uniform signature)
        coeff (float): pre-emphasis coefficient, typically in [0.0, 1.0]
    Returns:
        torch.Tensor: pre-emphasized signal (..., time)
    """
    emphasized = waveform.clone()
    previous = coeff * waveform[..., :-1]
    emphasized[..., 1:] -= previous
    return emphasized
def deemphasis(waveform, sample_rate: int, coeff: float = 0.97):
    """De-emphasize a waveform along the time dimension.

    Implements the IIR recurrence ``y[i] = x[i] + coeff * y[i - 1]``, the
    inverse of ``preemphasis``.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz (not used)
        coeff (float): de-emphasis coefficient. Typically between 0.0 and 1.0.
    Returns:
        torch.Tensor: de-emphasized signal (..., time). Not clamped, so values
            may exceed [-1, 1] (the DC gain is ``1 / (1 - coeff)``).
    """
    # Denominator [1, -coeff] and numerator [1, 0] give
    # y[n] - coeff * y[n - 1] = x[n].
    a_coeffs = waveform.new_tensor([1.0, -coeff])
    b_coeffs = waveform.new_tensor([1.0, 0.0])
    # lfilter clamps to [-1, 1] by default, which would silently clip the
    # recurrence's output; disable it so the documented formula holds.
    return torchaudio.functional.lfilter(
        waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs, clamp=False
    )
def clipping(
    waveform, sample_rate: int, min_quantile: float = 0.0, max_quantile: float = 0.9
):
    """Clip the signal to the range spanned by two of its own quantiles.

    The ``min_quantile`` and ``max_quantile`` quantiles are computed along the
    time axis. Samples outside that range are clamped to it, so roughly a
    ``min_quantile`` fraction of samples is clipped from below and a
    ``1 - max_quantile`` fraction from above.

    Args:
        waveform (torch.Tensor): floating-point audio signal (..., time)
        sample_rate (int): sampling rate in Hz (unused; kept for a uniform signature)
        min_quantile (float): quantile (0-1) used as the lower clipping level
        max_quantile (float): quantile (0-1) used as the upper clipping level
    Returns:
        torch.Tensor: clipped signal (..., time)
    """
    levels = waveform.new_tensor([min_quantile, max_quantile])
    bounds = torch.quantile(waveform, levels, dim=-1, keepdim=True)
    lower, upper = bounds[0], bounds[1]
    return torch.clamp(waveform, lower, upper)
def polarity_inverse(waveform, sample_rate):
    """Invert the polarity of the signal (multiply every sample by -1).

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz (unused; kept for a uniform signature)
    Returns:
        torch.Tensor: polarity-inverted signal (..., time)
    """
    inverted = -waveform
    return inverted
def reverse(waveform, sample_rate):
    """Play the signal backwards by reversing it along the last (time) axis.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz (unused; kept for a uniform signature)
    Returns:
        torch.Tensor: time-reversed signal (..., time)
    """
    time_axis = -1
    return torch.flip(waveform, dims=(time_axis,))
def corrupt_phase(
    waveform,
    sample_rate,
    scale: float = 0.5,
    n_fft: float = 0.032,
    win_length: Optional[float] = None,
    hop_length: Optional[float] = 0.008,
    window: Optional[str] = "hann",
):
    """Add random Gaussian noise to the STFT phase of the input waveform.

    Args:
        waveform (torch.Tensor): audio signal (..., time)
        sample_rate (int): sampling rate in Hz
        scale (float): standard deviation (in radians) of the phase noise
        n_fft (float): length of FFT (in seconds)
        win_length (float or None): window length (in seconds) used for the STFT.
            If None, it is treated as equal to n_fft.
        hop_length (float or None): hop size (in seconds) used for the STFT.
            If None, it is a quarter of n_fft.
        window (str or None): name of the torch window function (e.g. "hann"
            for ``torch.hann_window``) applied to each frame; None for no window
    Returns:
        torch.Tensor: phase-corrupted signal with the same length as the input
    """
    # Convert all STFT parameters from seconds to samples.
    n_fft = int(sample_rate * n_fft)
    if hop_length is None:
        hop_length = n_fft // 4
    else:
        hop_length = int(sample_rate * hop_length)
    if win_length is None:
        win_length = n_fft
    else:
        # Previously left in seconds, which produced a float window length.
        win_length = int(sample_rate * win_length)
    if window is not None:
        window_func = getattr(torch, f"{window}_window")
        window = window_func(win_length, dtype=waveform.dtype, device=waveform.device)
    spec = torch.stft(
        waveform,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    phase = torch.angle(spec)
    phase = torch.randn_like(phase) * scale + phase
    # Rebuild the complex spectrum from the original magnitude and noisy phase.
    spec = torch.polar(torch.abs(spec), phase)
    return torch.istft(
        spec,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        length=waveform.size(-1),
    )
# Registry mapping effect names used in `DataAugmentation` configs to the
# functions implementing them. Each entry is called as
# ``fn(waveform, sample_rate, **kwargs)``.
effects_dict = dict(
    lowpass=lowpass_filtering,
    highpass=highpass_filtering,
    bandpass=bandpass_filtering,
    bandreject=bandreject_filtering,
    contrast=contrast,
    equalization=equalization_filtering,
    pitch_shift=pitch_shift,
    speed_perturb=speed_perturb,
    time_stretch=time_stretch,
    preemphasis=preemphasis,
    deemphasis=deemphasis,
    clipping=clipping,
    polarity_inverse=polarity_inverse,
    reverse=reverse,
    corrupt_phase=corrupt_phase,
)