# Source code for espnet.nets.pytorch_backend.transformer.plot

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import logging
import os

import numpy

from espnet.asr import asr_utils

def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None):
    """Build a figure with one attention heatmap per head.

    Despite the name, this function does not write ``filename``. It only
    makes sure the parent directory of ``filename`` exists and returns the
    figure; saving is left to the caller.

    :param att_w: sequence of 2-D attention matrices, one per head. Each
        element must provide ``.astype`` (numpy arrays assumed — TODO confirm
        callers never pass raw torch tensors).
    :param str filename: intended output path; only its directory is used.
    :param list xtokens: optional labels for the x (column) axis.
    :param list ytokens: optional labels for the y (row) axis.
    :return: figure with ``len(att_w)`` subplots laid out in one row.
    """
    import matplotlib

    # Select the non-interactive backend before importing pyplot, as the
    # other plotting helpers in this module do.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    d = os.path.dirname(filename)
    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    if d:
        os.makedirs(d, exist_ok=True)

    w, h = plt.figaspect(1.0 / len(att_w))
    fig = plt.Figure(figsize=(w * 2, h * 2))
    axes = fig.subplots(1, len(att_w))
    if len(att_w) == 1:
        # subplots(1, 1) returns a bare Axes instead of an array of them.
        axes = [axes]
    for ax, aw in zip(axes, att_w):
        ax.imshow(aw.astype(numpy.float32), aspect="auto")
        # Rows/columns are positions, so fractional ticks are meaningless.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Labels for major ticks (these override the locators above).
        if xtokens is not None:
            ax.set_xticks(numpy.linspace(0, len(xtokens), len(xtokens) + 1))
            ax.set_xticks(numpy.linspace(0, len(xtokens), 1), minor=True)
            ax.set_xticklabels(xtokens + [""], rotation=40)
        if ytokens is not None:
            ax.set_yticks(numpy.linspace(0, len(ytokens), len(ytokens) + 1))
            ax.set_yticks(numpy.linspace(0, len(ytokens), 1), minor=True)
            ax.set_yticklabels(ytokens + [""])
    return fig

def savefig(plot, filename):
    """Save ``plot`` to ``filename`` and clear pyplot's current figure.

    :param plot: figure to save; anything with a ``savefig(filename)``
        method, such as the figure built by ``_plot_and_save_attention``.
    :param str filename: output path.
    """
    import matplotlib

    # Non-interactive backend so saving works on machines without a display.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    plot.savefig(filename)
    # NOTE(review): ``plot`` is normally a ``plt.Figure`` that pyplot does not
    # track, so this clears pyplot's *current* figure rather than ``plot``.
    # Kept as-is for backward compatibility.
    plt.clf()
def plot_multi_head_attention(
    data,
    uttid_list,
    attn_dict,
    outdir,
    suffix="png",
    savefn=savefig,
    ikey="input",
    iaxis=0,
    okey="output",
    oaxis=0,
    subsampling_factor=4,
):
    """Plot multi head attentions.

    :param dict data: utts info from json file
    :param List uttid_list: utterance IDs
    :param dict attn_dict: multi head attention dict, keyed by attention
        module name. Each value is iterated per utterance, and element ``idx``
        is paired with ``uttid_list[idx]``. From the cropping below, each
        per-utterance entry is indexed ``[head, query_len, key_len]`` and must
        support ``.astype`` (numpy assumed — TODO confirm).
    :param str outdir: dir to save fig
    :param str suffix: filename suffix including image type (e.g., png)
    :param savefn: function called as ``savefn(fig, filename)``
    :param str ikey: key to access input
    :param int iaxis: dimension to access input
    :param str okey: key to access output
    :param int oaxis: dimension to access output
    :param subsampling_factor: subsampling factor in encoder

    Cropping depends on the attention ``name``:

    * contains ``"encoder"``: ``[:, :enc_len, :enc_len]``
    * contains ``"decoder"`` and ``"self"``: ``[:, :dec_len, :dec_len]``
    * other ``"decoder"`` names (cross-attention): ``[:, :dec_len, :enc_len]``
    * anything else: a warning is logged and the matrix is plotted uncropped.
    """
    for name, att_ws in attn_dict.items():
        for idx, att_w in enumerate(att_ws):
            uttid = uttid_list[idx]
            data_i = data[uttid]
            in_info = data_i[ikey][iaxis]
            out_info = data_i[okey][oaxis]
            filename = "%s/%s.%s.%s" % (outdir, uttid, name, suffix)
            dec_len = int(out_info["shape"][0]) + 1  # +1 for <eos>
            enc_len = int(in_info["shape"][0])
            # A tokenized input means text input (MT). Otherwise (ASR/ST) the
            # encoder output is shorter by the subsampling factor.
            is_mt = "token" in in_info
            if not is_mt:
                enc_len //= subsampling_factor
            xtokens, ytokens = None, None
            if "encoder" in name:
                att_w = att_w[:, :enc_len, :enc_len]
                # for MT
                if is_mt:
                    xtokens = in_info["token"].split()
                    ytokens = xtokens[:]
            elif "decoder" in name:
                out_tokens = (
                    out_info["token"].split() if "token" in out_info else None
                )
                if "self" in name:
                    # self-attention: queries and keys are both decoder steps
                    att_w = att_w[:, :dec_len, :dec_len]
                    if out_tokens is not None:
                        ytokens = out_tokens + ["<eos>"]
                        xtokens = ["<sos>"] + out_tokens
                else:
                    # cross-attention: decoder queries over encoder keys
                    att_w = att_w[:, :dec_len, :enc_len]
                    if out_tokens is not None:
                        ytokens = out_tokens + ["<eos>"]
                    # for MT
                    if is_mt:
                        xtokens = in_info["token"].split()
            else:
                logging.warning("unknown name for shaping attention")
            fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens)
            savefn(fig, filename)
class PlotAttentionReport(asr_utils.PlotAttentionReport):
    """Attention reporter that plots every head via plot_multi_head_attention.

    Relies on attributes prepared by ``asr_utils.PlotAttentionReport``
    (``data_dict``, ``outdir``, ``transform``, ``converter``, ``device``,
    ``att_vis_fn``, ``ikey``, ``iaxis``, ``okey``, ``oaxis``, ``factor``).
    """

    def plotfn(self, *args, **kwargs):
        """Call plot_multi_head_attention with this reporter's key/axis settings."""
        kwargs["ikey"] = self.ikey
        kwargs["iaxis"] = self.iaxis
        kwargs["okey"] = self.okey
        kwargs["oaxis"] = self.oaxis
        kwargs["subsampling_factor"] = self.factor
        plot_multi_head_attention(*args, **kwargs)

    def __call__(self, trainer):
        """Write attention images to ``self.outdir``, tagged with the epoch."""
        attn_dict, uttid_list = self.get_attention_weights()
        suffix = "ep.{.updater.epoch}.png".format(trainer)
        self.plotfn(
            self.data_dict, uttid_list, attn_dict, self.outdir, suffix, savefig
        )

    def get_attention_weights(self):
        """Run ``att_vis_fn`` on the stored samples.

        :return: tuple ``(att_ws, uttid_list)``.
        :raises TypeError: if the converter returns neither a tuple nor a dict.
        """
        # NOTE(review): the first argument was lost in SOURCE. ``self.data`` is
        # assumed to be the sample list kept by the base class; confirm
        # against asr_utils.PlotAttentionReport.
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        elif isinstance(batch, dict):
            att_ws = self.att_vis_fn(**batch)
        else:
            # Previously fell through to an UnboundLocalError on att_ws.
            raise TypeError(
                "converter returned unsupported batch type: %s" % type(batch)
            )
        return att_ws, uttid_list

    def log_attentions(self, logger, step):
        """Send attention figures to ``logger`` instead of writing files.

        :param logger: object with ``add_figure(tag, figure, step)``
            (presumably a TensorBoard writer — confirm).
        :param int step: step value passed through to ``logger.add_figure``.
        """

        def log_fig(plot, filename):
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            # The basename of the would-be file path serves as the figure tag.
            logger.add_figure(os.path.basename(filename), plot, step)
            plt.clf()

        attn_dict, uttid_list = self.get_attention_weights()
        # Empty suffix: filenames only feed the tag here, nothing is written.
        self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, "", log_fig)