espnet.nets.pytorch_backend.transformer.plot.plot_multi_head_attention
espnet.nets.pytorch_backend.transformer.plot.plot_multi_head_attention(data, uttid_list, attn_dict, outdir, suffix='png', savefn=<function savefig>, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=4)
Plot multi-head attentions.
- Parameters:
- data (dict) – utts info from json file
- uttid_list (List) – utterance IDs
- attn_dict (dict[str, torch.Tensor]) – multi-head attention dict. Values should be torch.Tensor of shape (head, input_length, output_length)
- outdir (str) – directory in which to save the figures
- suffix (str) – filename suffix including image type (e.g., png)
- savefn – function used to save each figure (defaults to matplotlib's savefig)
- ikey (str) – key to access input
- iaxis (int) – dimension to access input
- okey (str) – key to access output
- oaxis (int) – dimension to access output
- subsampling_factor (int) – subsampling factor applied in the encoder
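A minimal sketch of the argument layout this function expects. The utterance ID, attention-dict key, and dimensions below are hypothetical illustrations; plain nested lists stand in for torch.Tensor so the sketch stays dependency-free. With real data, each attn_dict value would be a torch.Tensor of shape (head, input_length, output_length).

```python
# data: utterance info as loaded from an ESPnet json file, keyed by utterance ID.
# The "shape" entries are accessed via ikey/iaxis and okey/oaxis.
data = {
    "utt1": {
        "input":  [{"shape": [83, 80]}],  # ikey="input", iaxis=0: 83 frames, 80 dims
        "output": [{"shape": [20]}],      # okey="output", oaxis=0: 20 output tokens
    }
}
uttid_list = ["utt1"]

# attn_dict: attention weights keyed by a layer name (key naming is hypothetical).
# Shape follows the documented (head, input_length, output_length); input_length
# here reflects 83 encoder frames reduced by subsampling_factor=4.
head, in_len, out_len = 4, 21, 20
attn_dict = {
    "utt1.decoder.self_attn": [
        [[0.0] * out_len for _ in range(in_len)] for _ in range(head)
    ]
}

# With real torch.Tensor values, the call would look like:
# plot_multi_head_attention(data, uttid_list, attn_dict, outdir="exp/att_ws",
#                           suffix="png", subsampling_factor=4)
```

One PNG per utterance/attention key would be written under outdir, named with the given suffix.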