espnet.nets.pytorch_backend.transformer.plot.PlotAttentionReport
class espnet.nets.pytorch_backend.transformer.plot.PlotAttentionReport(att_vis_fn, data, outdir, converter, transform, device, reverse=False, ikey='input', iaxis=0, okey='output', oaxis=0, subsampling_factor=1)
Bases: espnet.asr.asr_utils.PlotAttentionReport
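A minimal construction sketch, not the exact ESPnet training recipe: `model`, `valid_json`, `converter`, `load_cv`, and `device` are placeholders for objects that the surrounding training script would normally provide, and the output directory is only illustrative.

```python
# Hedged sketch; every lowercase name except the class itself is a placeholder
# standing in for objects created elsewhere in an ESPnet training script.
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport

att_reporter = PlotAttentionReport(
    att_vis_fn=model.calculate_all_attentions,  # callable that returns attention weights
    data=sorted(valid_json.items())[:3],        # a few validation utterances to visualize
    outdir="exp/results/att_ws",                # directory where attention plots are saved
    converter=converter,                        # batch converter used during training
    transform=load_cv,                          # feature loading / transform pipeline
    device=device,                              # torch device the model lives on
)
```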
get_attention_weights()
Return attention weights.
- Returns: attention weights (float). The shape depends on the backend:
  - pytorch: (B, H, Lmax, Tmax) in the multi-head case, (B, Lmax, Tmax) otherwise
  - chainer: (B, Lmax, Tmax)
- Return type: numpy.ndarray
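A hedged consumer-side sketch, assuming the documented numpy.ndarray return and the `att_reporter` instance from the construction sketch above:

```python
import numpy as np

# Assumes `att_reporter` from the sketch above; a 4-D array corresponds to the
# multi-head pytorch layout (B, H, Lmax, Tmax), otherwise (B, Lmax, Tmax).
att_ws = np.asarray(att_reporter.get_attention_weights())

if att_ws.ndim == 4:
    n_batch, n_head, l_max, t_max = att_ws.shape
else:
    n_batch, l_max, t_max = att_ws.shape
    n_head = 1
print(f"{n_batch} utterance(s), {n_head} head(s), {l_max} x {t_max} attention maps")
```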
log_attentions(logger, step)
Add images of the att_ws matrices to TensorBoard.
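A hedged logging sketch, assuming a torch.utils.tensorboard.SummaryWriter as the logger and the `att_reporter` instance from above; `global_step` is a placeholder for the current training step counter.

```python
from torch.utils.tensorboard import SummaryWriter

# `att_reporter` comes from the construction sketch above; `global_step` stands
# in for the current epoch/iteration counter of the training loop.
writer = SummaryWriter(log_dir="exp/results/tensorboard")
global_step = 1
att_reporter.log_attentions(writer, global_step)
writer.close()
```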
plotfn(*args, **kwargs)