espnet.asr.asr_mix_utils.PlotAttentionReport
class espnet.asr.asr_mix_utils.PlotAttentionReport(att_vis_fn, data, outdir, converter, device, reverse=False)
Bases: Extension
Plot attention reporter.
- Parameters:
- att_vis_fn (espnet.nets.*_backend.e2e_asr.calculate_all_attentions) – Function that computes the attention weights to visualize.
- data (list[tuple(str, dict[str, dict[str, Any]])]) – List of (utterance key, JSON info) items.
- outdir (str) – Directory to save figures.
- converter (espnet.asr.*_backend.asr.CustomConverter) – CustomConverter object used to convert data.
- device (torch.device) – The destination device to send tensor.
- reverse (bool) – If True, input and output lengths are reversed.
Initialize PlotAttentionReport.
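The `data` argument is a list of `(utterance key, info dict)` pairs taken from a JSON data file. The sketch below shows one way to build such a list. It assumes the usual ESPnet `"input"`/`"output"`/`"shape"` layout and assumes that a multi-speaker utterance has one `"output"` entry per speaker. The helper `select_plot_data`, the utterance values, and the longest-first ordering are illustrative choices, not the library's code.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical utterance entries in the assumed ESPnet "utts" layout; values are made up.
# Each utterance has one input and (assumed) one output entry per speaker.
UttInfo = Dict[str, Any]

valid_json: Dict[str, UttInfo] = {
    "utt_a": {"input": [{"shape": [120, 83]}], "output": [{"shape": [10, 52]}, {"shape": [12, 52]}]},
    "utt_b": {"input": [{"shape": [300, 83]}], "output": [{"shape": [25, 52]}, {"shape": [20, 52]}]},
    "utt_c": {"input": [{"shape": [200, 83]}], "output": [{"shape": [18, 52]}, {"shape": [15, 52]}]},
}


def select_plot_data(utts: Dict[str, UttInfo], num_save: int) -> List[Tuple[str, UttInfo]]:
    """Hypothetical helper: keep the first `num_save` utterances, longest input first."""
    items = list(utts.items())[:num_save]
    # shape[0] is taken as the number of input frames in this sketch.
    return sorted(items, key=lambda kv: int(kv[1]["input"][0]["shape"][0]), reverse=True)


data = select_plot_data(valid_json, num_save=2)
print([key for key, _ in data])  # ['utt_b', 'utt_a']
```

Each element of `data` is then a tuple in the shape the `data` parameter describes.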
draw_attention_plot(att_w)
Visualize attention weights matrix.
- Parameters: att_w (Tensor) – Attention weight matrix.
- Returns: pyplot object with attention matrix image.
- Return type: matplotlib.pyplot
get_attention_weight(idx, att_w, spkr_idx)
Transform the attention weight of utterance idx for speaker spkr_idx, taking self.reverse into account.
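Batched attention weights are padded to `Lmax` x `Tmax`, so a per-utterance transform typically crops each matrix to its true lengths before plotting. The sketch below shows that idea and how `reverse` could swap which length bounds which axis. It assumes the lengths are read from an ESPnet-style info dict with one `"output"` entry per speaker. `trim_attention` is a hypothetical helper, not the method's actual implementation.

```python
import numpy as np


def trim_attention(att_w, utt_info, spkr_idx=0, reverse=False):
    """Hypothetical helper: crop padded attention to one utterance's true lengths.

    att_w has shape (Lmax, Tmax) or, for multi-head attention, (H, Lmax, Tmax).
    Lengths come from an ESPnet-style utterance dict (assumed layout).
    """
    in_len = int(utt_info["input"][0]["shape"][0])
    out_len = int(utt_info["output"][spkr_idx]["shape"][0])
    # With reverse=True the roles of the input and output lengths are swapped.
    dec_len, enc_len = (in_len, out_len) if reverse else (out_len, in_len)
    # Slicing the last two axes handles both the 2-D and the 3-D (multi-head) case.
    return att_w[..., :dec_len, :enc_len]


utt = {"input": [{"shape": [6, 83]}], "output": [{"shape": [3, 52]}, {"shape": [4, 52]}]}
padded = np.random.rand(2, 5, 8)  # (H, Lmax, Tmax) with padding
print(trim_attention(padded, utt, spkr_idx=1).shape)  # (2, 4, 6)
print(trim_attention(np.random.rand(8, 8), utt, reverse=True).shape)  # (6, 3)
```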
get_attention_weights()
Return attention weights.
- Returns: attention weights (dtype=float). Their shape depends on the backend:
  - pytorch: (B, H, Lmax, Tmax) in the multi-head case, otherwise (B, Lmax, Tmax).
  - chainer: (B, Lmax, Tmax).
- Return type: arr_ws_sd (numpy.ndarray)
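Because the returned array is either 3-D or 4-D, code that plots or logs it usually has to branch on the number of dimensions. The sketch below normalizes both documented shapes into a stream of 2-D `(Lmax, Tmax)` maps. `iter_attention_maps` is a hypothetical helper, and it covers only the shapes listed above.

```python
import numpy as np


def iter_attention_maps(att_ws):
    """Yield (utt_index, head_index, 2-D map) from a batch of attention weights.

    Accepts (B, Lmax, Tmax) or multi-head (B, H, Lmax, Tmax); head_index is None
    when there is no head axis.
    """
    att_ws = np.asarray(att_ws, dtype=np.float32)
    if att_ws.ndim == 3:
        for b, att in enumerate(att_ws):
            yield b, None, att
    elif att_ws.ndim == 4:
        for b, heads in enumerate(att_ws):
            for h, att in enumerate(heads):
                yield b, h, att
    else:
        raise ValueError(f"expected a 3-D or 4-D array, got shape {att_ws.shape}")


single = np.zeros((2, 5, 7))    # (B, Lmax, Tmax)
multi = np.zeros((2, 4, 5, 7))  # (B, H, Lmax, Tmax)
print(len(list(iter_attention_maps(single))))  # 2
print(len(list(iter_attention_maps(multi))))   # 8
```

A plotting loop can then treat every yielded map the same way, whatever the backend.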
log_attentions(logger, step)
Add image files of the attention matrix to TensorBoard.