espnet.utils.spec_augment.cross_squared_distance_matrix
espnet.utils.spec_augment.cross_squared_distance_matrix(x, y)
Pairwise squared distance between two (batch) matrices’ rows (2nd dim).
Computes the pairwise squared distances between rows of x and rows of y.

Args:
    x: [batch_size, n, d] float Tensor
    y: [batch_size, m, d] float Tensor

Returns:
    squared_dists: [batch_size, n, m] float Tensor, where
        squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
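
The index semantics of the return value can be illustrated with a small dependency-free sketch. This is not the ESPnet implementation: the helper name `cross_squared_distance_reference` is hypothetical, and it works on nested Python lists rather than tensors, purely to spell out what `squared_dists[b,i,j]` means.

```python
def cross_squared_distance_reference(x, y):
    """Naive reference for the documented formula (hypothetical helper).

    x: nested list shaped [batch_size][n][d]
    y: nested list shaped [batch_size][m][d]
    Returns a nested list shaped [batch_size][n][m] where entry [b][i][j]
    is the squared Euclidean distance between x[b][i] and y[b][j].
    """
    return [
        [
            [sum((xk - yk) ** 2 for xk, yk in zip(xi, yj)) for yj in yb]
            for xi in xb
        ]
        for xb, yb in zip(x, y)
    ]


# One batch, n=2 rows in x, m=3 rows in y, d=2 features per row.
x = [[[0.0, 0.0], [1.0, 2.0]]]
y = [[[0.0, 0.0], [3.0, 4.0], [1.0, 0.0]]]

squared_dists = cross_squared_distance_reference(x, y)
print(squared_dists)  # [[[0.0, 25.0, 1.0], [5.0, 8.0, 4.0]]]
```

Here `squared_dists[0][0][1]` is 25.0 because the rows (0, 0) and (3, 4) differ by 3 and 4, giving 3^2 + 4^2. The output has one row per row of x and one column per row of y, matching the documented [batch_size, n, m] shape.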