Source code for espnet.nets.pytorch_backend.e2e_asr_mulenc

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Copyright 2017 Johns Hopkins University (Ruizhi Li)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Define e2e module for multi-encoder network. https://arxiv.org/pdf/1811.04903.pdf."""

import argparse
import logging
import math
import os
from itertools import groupby

import chainer
import numpy as np
import torch
from chainer import reporter

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.e2e_asr_common import label_smoothing_dist
from espnet.nets.pytorch_backend.ctc import ctc_for
from espnet.nets.pytorch_backend.nets_utils import (
    get_subsample,
    pad_list,
    to_device,
    to_torch_tensor,
)
from espnet.nets.pytorch_backend.rnn.attentions import att_for
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for
from espnet.nets.pytorch_backend.rnn.encoders import Encoder, encoder_for
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.utils.cli_utils import strtobool

CTC_LOSS_THRESHOLD = 10000


class Reporter(chainer.Chain):
    """Define a chainer reporter wrapper."""

    def report(self, loss_ctc_list, loss_att, acc, cer_ctc_list, cer, wer, mtl_loss):
        """Define a chainer reporter function."""
        # loss_ctc_list = [weighted CTC, CTC1, CTC2, ... CTCN]
        # cer_ctc_list = [weighted cer_ctc, cer_ctc_1, cer_ctc_2, ... cer_ctc_N]
        num_encs = len(loss_ctc_list) - 1
        reporter.report({"loss_ctc": loss_ctc_list[0]}, self)
        for i in range(num_encs):
            reporter.report({"loss_ctc{}".format(i + 1): loss_ctc_list[i + 1]}, self)
        reporter.report({"loss_att": loss_att}, self)
        reporter.report({"acc": acc}, self)
        reporter.report({"cer_ctc": cer_ctc_list[0]}, self)
        for i in range(num_encs):
            reporter.report({"cer_ctc{}".format(i + 1): cer_ctc_list[i + 1]}, self)
        reporter.report({"cer": cer}, self)
        reporter.report({"wer": wer}, self)
        logging.info("mtl loss:" + str(mtl_loss))
        reporter.report({"loss": mtl_loss}, self)

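# Hedged usage sketch (not part of the original module): what Reporter.report
# receives for a two-encoder setup.  The numeric values below are made up for
# illustration only; the weighted entry always comes first, followed by one
# entry per encoder.
#
#   reporter = Reporter()
#   reporter.report(
#       loss_ctc_list=[12.3, 11.0, 13.6],   # [weighted CTC, CTC1, CTC2]
#       loss_att=8.7, acc=0.61,
#       cer_ctc_list=[0.25, 0.22, 0.28],    # [weighted cer_ctc, cer_ctc_1, cer_ctc_2]
#       cer=0.20, wer=0.35, mtl_loss=10.1,
#   )
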
class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    :param List idims: List of dimensions of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments for multi-encoder setting."""
        E2E.encoder_add_arguments(parser)
        E2E.attention_add_arguments(parser)
        E2E.decoder_add_arguments(parser)
        E2E.ctc_add_arguments(parser)
        return parser

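    # Hedged usage sketch (not part of the original module): the per-encoder
    # options below use argparse's action="append", so each flag is repeated
    # once per encoder on the command line.  Options such as --num-encs and
    # --mtlalpha are assumed to be supplied by the training entry point and are
    # not added here.
    #
    #   parser = argparse.ArgumentParser()
    #   parser = E2E.add_arguments(parser)
    #   args, _ = parser.parse_known_args(
    #       ["--etype", "vggblstmp", "--elayers", "4", "--eunits", "320",
    #        "--etype", "vggblstmp", "--elayers", "4", "--eunits", "320",
    #        "--atype", "location", "--adim", "320"]
    #   )
    #   # args.etype == ["vggblstmp", "vggblstmp"], i.e. one entry per encoder
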
    @staticmethod
    def encoder_add_arguments(parser):
        """Add arguments for encoders in multi-encoder setting."""
        group = parser.add_argument_group("E2E encoder setting")
        group.add_argument(
            "--etype", action="append", type=str,
            choices=["lstm", "blstm", "lstmp", "blstmp", "vgglstmp", "vggblstmp",
                     "vgglstm", "vggblstm", "gru", "bgru", "grup", "bgrup",
                     "vgggrup", "vggbgrup", "vgggru", "vggbgru"],
            help="Type of encoder network architecture",
        )
        group.add_argument(
            "--elayers", type=int, action="append",
            help="Number of encoder layers "
            "(for shared recognition part in multi-speaker asr mode)",
        )
        group.add_argument(
            "--eunits", "-u", type=int, action="append",
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--eprojs", default=320, type=int,
            help="Number of encoder projection units",
        )
        group.add_argument(
            "--subsample", type=str, action="append",
            help="Subsample input frames x_y_z means "
            "subsample every x frame at 1st layer, "
            "every y frame at 2nd layer etc.",
        )
        return parser

    @staticmethod
    def attention_add_arguments(parser):
        """Add arguments for attentions in multi-encoder setting."""
        group = parser.add_argument_group("E2E attention setting")
        # attention
        group.add_argument(
            "--atype", type=str, action="append",
            choices=["noatt", "dot", "add", "location", "coverage",
                     "coverage_location", "location2d", "location_recurrent",
                     "multi_head_dot", "multi_head_add", "multi_head_loc",
                     "multi_head_multi_res_loc"],
            help="Type of attention architecture",
        )
        group.add_argument(
            "--adim", type=int, action="append",
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--awin", type=int, action="append",
            help="Window size for location2d attention",
        )
        group.add_argument(
            "--aheads", type=int, action="append",
            help="Number of heads for multi head attention",
        )
        group.add_argument(
            "--aconv-chans", type=int, action="append",
            help="Number of attention convolution channels "
            "(negative value indicates no location-aware attention)",
        )
        group.add_argument(
            "--aconv-filts", type=int, action="append",
            help="Number of attention convolution filters "
            "(negative value indicates no location-aware attention)",
        )
        group.add_argument(
            "--dropout-rate", type=float, action="append",
            help="Dropout rate for the encoder",
        )
        # hierarchical attention network (HAN)
        group.add_argument(
            "--han-type", default="dot", type=str,
            choices=["noatt", "dot", "add", "location", "coverage",
                     "coverage_location", "location2d", "location_recurrent",
                     "multi_head_dot", "multi_head_add", "multi_head_loc",
                     "multi_head_multi_res_loc"],
            help="Type of attention architecture (multi-encoder asr mode only)",
        )
        group.add_argument(
            "--han-dim", default=320, type=int,
            help="Number of attention transformation dimensions in HAN",
        )
        group.add_argument(
            "--han-win", default=5, type=int,
            help="Window size for location2d attention in HAN",
        )
        group.add_argument(
            "--han-heads", default=4, type=int,
            help="Number of heads for multi head attention in HAN",
        )
        group.add_argument(
            "--han-conv-chans", default=-1, type=int,
            help="Number of attention convolution channels in HAN "
            "(negative value indicates no location-aware attention)",
        )
        group.add_argument(
            "--han-conv-filts", default=100, type=int,
            help="Number of attention convolution filters in HAN "
            "(negative value indicates no location-aware attention)",
        )
        return parser

    @staticmethod
    def decoder_add_arguments(parser):
        """Add arguments for decoder in multi-encoder setting."""
        group = parser.add_argument_group("E2E decoder setting")
        group.add_argument(
            "--dtype", default="lstm", type=str, choices=["lstm", "gru"],
            help="Type of decoder network architecture",
        )
        group.add_argument(
            "--dlayers", default=1, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=320, type=int, help="Number of decoder hidden units"
        )
        group.add_argument(
            "--dropout-rate-decoder", default=0.0, type=float,
            help="Dropout rate for the decoder",
        )
        group.add_argument(
            "--sampling-probability", default=0.0, type=float,
            help="Ratio of predicted labels fed back to decoder",
        )
        group.add_argument(
            "--lsm-type", const="", default="", type=str, nargs="?",
            choices=["", "unigram"],
            help="Apply label smoothing with a specified distribution type",
        )
        return parser

    @staticmethod
    def ctc_add_arguments(parser):
        """Add arguments for ctc in multi-encoder setting."""
        group = parser.add_argument_group("E2E multi-ctc setting")
        group.add_argument(
            "--share-ctc", type=strtobool, default=False,
            help="The flag to switch to share ctc across multiple encoders "
            "(multi-encoder asr mode only).",
        )
        group.add_argument(
            "--weights-ctc-train", type=float, action="append",
            help="ctc weight assigned to each encoder during training.",
        )
        group.add_argument(
            "--weights-ctc-dec", type=float, action="append",
            help="ctc weight assigned to each encoder during decoding.",
        )
        return parser

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        if isinstance(self.enc, Encoder):
            return self.enc.conv_subsampling_factor * int(
                np.prod(self.subsample_list[0])
            )
        else:
            return self.enc[0].conv_subsampling_factor * int(
                np.prod(self.subsample_list[0])
            )

    def __init__(self, idims, odim, args):
        """Initialize this class with python-level args.

        Args:
            idims (list): list of the number of an input feature dim.
            odim (int): The number of output vocab.
            args (Namespace): arguments

        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()
        self.num_encs = args.num_encs
        self.share_ctc = args.share_ctc

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample_list = get_subsample(args, mode="asr", arch="rnn_mulenc")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        # speech translation related
        self.replace_sos = getattr(
            args, "replace_sos", False
        )  # use getattr to keep compatibility

        self.frontend = None

        # encoder
        self.enc = encoder_for(args, idims, self.subsample_list)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # hierarchical attention network
        han = att_for(args, han_mode=True)
        self.att.append(han)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        if args.mtlalpha > 0 and self.num_encs > 1:
            # weights-ctc,
            # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
            self.weights_ctc_train = args.weights_ctc_train / np.sum(
                args.weights_ctc_train
            )  # normalize
            self.weights_ctc_dec = args.weights_ctc_dec / np.sum(
                args.weights_ctc_dec
            )  # normalize
            logging.info(
                "ctc weights (training during training): "
                + " ".join([str(x) for x in self.weights_ctc_train])
            )
            logging.info(
                "ctc weights (decoding during training): "
                + " ".join([str(x) for x in self.weights_ctc_dec])
            )
        else:
            self.weights_ctc_dec = [1.0]
            self.weights_ctc_train = [1.0]

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if args.report_cer or args.report_wer:
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
                "tgt_lang": False,
                "ctc_weights_dec": self.weights_ctc_dec,
            }
            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None

    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """

        def lecun_normal_init_parameters(module):
            for p in module.parameters():
                data = p.data
                if data.dim() == 1:
                    # bias
                    data.zero_()
                elif data.dim() == 2:
                    # linear weight
                    n = data.size(1)
                    stdv = 1.0 / math.sqrt(n)
                    data.normal_(0, stdv)
                elif data.dim() in (3, 4):
                    # conv weight
                    n = data.size(1)
                    for k in data.size()[2:]:
                        n *= k
                    stdv = 1.0 / math.sqrt(n)
                    data.normal_(0, stdv)
                else:
                    raise NotImplementedError

        def set_forget_bias_to_one(bias):
            n = bias.size(0)
            start, end = n // 4, n // 2
            bias.data[start:end].fill_(1.0)

        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)

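    # Hedged illustration (not part of the original module) of the LeCun-style
    # initialization described in the docstring above, applied to a standalone
    # layer: weights are drawn from Normal(0, fan_in ** -0.5) and biases zeroed.
    #
    #   linear = torch.nn.Linear(320, 320)
    #   linear.weight.data.normal_(0, 1.0 / math.sqrt(linear.weight.size(1)))
    #   linear.bias.data.zero_()
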
    def forward(self, xs_pad_list, ilens_list, ys_pad):
        """E2E forward.

        :param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences
            [(B, Tmax_1, idim), (B, Tmax_2, idim),..]
        :param List ilens_list: list of batch (torch.Tensor) of lengths of input sequences
            [(B), (B), ..]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        import editdistance

        if self.replace_sos:
            tgt_lang_ids = ys_pad[:, 0:1]
            ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning
        else:
            tgt_lang_ids = None

        hs_pad_list, hlens_list, self.loss_ctc_list = [], [], []
        for idx in range(self.num_encs):
            # 1. Encoder
            hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx])

            # 2. CTC loss
            if self.mtlalpha == 0:
                self.loss_ctc_list.append(None)
            else:
                ctc_idx = 0 if self.share_ctc else idx
                loss_ctc = self.ctc[ctc_idx](hs_pad, hlens, ys_pad)
                self.loss_ctc_list.append(loss_ctc)
            hs_pad_list.append(hs_pad)
            hlens_list.append(hlens)

        # 3. attention loss
        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            self.loss_att, acc, _ = self.dec(
                hs_pad_list, hlens_list, ys_pad, lang_ids=tgt_lang_ids
            )
        self.acc = acc

        # 4. compute cer without beam search
        if self.mtlalpha == 0 or self.char_list is None:
            cer_ctc_list = [None] * (self.num_encs + 1)
        else:
            cer_ctc_list = []
            for ind in range(self.num_encs):
                cers = []
                ctc_idx = 0 if self.share_ctc else ind
                y_hats = self.ctc[ctc_idx].argmax(hs_pad_list[ind]).data
                for i, y in enumerate(y_hats):
                    y_hat = [x[0] for x in groupby(y)]
                    y_true = ys_pad[i]
                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                    seq_hat_text = seq_hat_text.replace(self.blank, "")
                    seq_true_text = "".join(seq_true).replace(self.space, " ")

                    hyp_chars = seq_hat_text.replace(" ", "")
                    ref_chars = seq_true_text.replace(" ", "")
                    if len(ref_chars) > 0:
                        cers.append(
                            editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
                        )

                cer_ctc = sum(cers) / len(cers) if cers else None
                cer_ctc_list.append(cer_ctc)
            cer_ctc_weighted = np.sum(
                [
                    item * self.weights_ctc_train[i]
                    for i, item in enumerate(cer_ctc_list)
                ]
            )
            cer_ctc_list = [float(cer_ctc_weighted)] + [
                float(item) for item in cer_ctc_list
            ]

        # 5. compute cer/wer
        if self.training or not (self.report_cer or self.report_wer):
            cer, wer = 0.0, 0.0
            # oracle_cer, oracle_wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz_list = []
                for idx in range(self.num_encs):
                    ctc_idx = 0 if self.share_ctc else idx
                    lpz = self.ctc[ctc_idx].log_softmax(hs_pad_list[idx]).data
                    lpz_list.append(lpz)
            else:
                lpz_list = None

            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad_list,
                hlens_list,
                lpz_list,
                self.recog_args,
                self.char_list,
                self.rnnlm,
                lang_ids=tgt_lang_ids.squeeze(1).tolist()
                if self.replace_sos
                else None,
            )
            # remove <sos> and <eos>
            y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]
                seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
                seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                seq_true_text = "".join(seq_true).replace(self.recog_args.space, " ")

                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))
                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))
            wer = (
                0.0
                if not self.report_wer
                else float(sum(word_eds)) / sum(word_ref_lens)
            )
            cer = (
                0.0
                if not self.report_cer
                else float(sum(char_eds)) / sum(char_ref_lens)
            )

        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data_list = [None] * (self.num_encs + 1)
        elif alpha == 1:
            self.loss = torch.sum(
                torch.cat(
                    [
                        (item * self.weights_ctc_train[i]).unsqueeze(0)
                        for i, item in enumerate(self.loss_ctc_list)
                    ]
                )
            )
            loss_att_data = None
            loss_ctc_data_list = [float(self.loss)] + [
                float(item) for item in self.loss_ctc_list
            ]
        else:
            self.loss_ctc = torch.sum(
                torch.cat(
                    [
                        (item * self.weights_ctc_train[i]).unsqueeze(0)
                        for i, item in enumerate(self.loss_ctc_list)
                    ]
                )
            )
            self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data_list = [float(self.loss_ctc)] + [
                float(item) for item in self.loss_ctc_list
            ]

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data_list,
                loss_att_data,
                acc,
                cer_ctc_list,
                cer,
                wer,
                loss_data,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

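    # Hedged training-step sketch (illustrative only; "model", "optimizer" and the
    # batch tensors below are assumptions, not part of this module).  forward()
    # takes one padded feature tensor and one length tensor per encoder and returns
    # the multi-task loss alpha * ctc + (1 - alpha) * att, where the CTC term is a
    # weighted sum over encoders (weights_ctc_train).
    #
    #   xs_pad_list = [torch.randn(8, 200, 83), torch.randn(8, 180, 83)]  # two encoders
    #   ilens_list = [torch.full((8,), 200, dtype=torch.long),
    #                 torch.full((8,), 180, dtype=torch.long)]
    #   ys_pad = torch.randint(0, model.eos, (8, 30))  # real batches pad targets with -1
    #   loss = model(xs_pad_list, ilens_list, ys_pad)
    #   optimizer.zero_grad()
    #   loss.backward()
    #   optimizer.step()
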
    def scorers(self):
        """Get scorers for `beam_search` (optional).

        Returns:
            dict[str, ScorerInterface]: dict of `ScorerInterface` objects

        """
        return dict(decoder=self.dec, ctc=CTCPrefixScorer(self.ctc, self.eos))

    def encode(self, x_list):
        """Encode feature.

        Args:
            x_list (list): input feature [(T1, D), (T2, D), ... ]

        Returns:
            list: encoded feature [(T1, D), (T2, D), ... ]

        """
        self.eval()
        ilens_list = [[x_list[idx].shape[0]] for idx in range(self.num_encs)]

        # subsample frame
        x_list = [
            x_list[idx][:: self.subsample_list[idx][0], :]
            for idx in range(self.num_encs)
        ]
        p = next(self.parameters())
        x_list = [
            torch.as_tensor(x_list[idx], device=p.device, dtype=p.dtype)
            for idx in range(self.num_encs)
        ]
        # make a utt list (1) to use the same interface for encoder
        xs_list = [
            x_list[idx].contiguous().unsqueeze(0) for idx in range(self.num_encs)
        ]

        # 1. encoder
        hs_list = []
        for idx in range(self.num_encs):
            hs, _, _ = self.enc[idx](xs_list[idx], ilens_list[idx])
            hs_list.append(hs[0])
        return hs_list

    def recognize(self, x_list, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param list x_list: list of input acoustic feature arrays [(T1, D), (T2, D), ...]
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        hs_list = self.encode(x_list)

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            if self.share_ctc:
                lpz_list = [
                    self.ctc[0].log_softmax(hs_list[idx].unsqueeze(0))[0]
                    for idx in range(self.num_encs)
                ]
            else:
                lpz_list = [
                    self.ctc[idx].log_softmax(hs_list[idx].unsqueeze(0))[0]
                    for idx in range(self.num_encs)
                ]
        else:
            lpz_list = None

        # 2. Decoder
        # decode the first utterance
        y = self.dec.recognize_beam(hs_list, lpz_list, recog_args, char_list, rnnlm)
        return y

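    # Hedged decoding sketch (illustrative; "model", "char_list" and the feature
    # matrices are assumptions, and the recog_args fields shown are a typical
    # beam-search configuration rather than values defined in this module).
    # One feature matrix is passed per encoder.
    #
    #   recog_args = argparse.Namespace(
    #       beam_size=20, penalty=0.0, ctc_weight=0.3, maxlenratio=0.0,
    #       minlenratio=0.0, lm_weight=0.0, nbest=1, space="<space>",
    #       blank="<blank>", ctc_weights_dec=[0.5, 0.5],
    #   )
    #   nbest_hyps = model.recognize(
    #       [feat_enc1, feat_enc2], recog_args, char_list, rnnlm=None
    #   )
    #   best_token_ids = nbest_hyps[0]["yseq"]  # includes <sos>/<eos> ids
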
    def recognize_batch(self, xs_list, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param list xs_list: list of list of input acoustic feature arrays
            [[(T1_1, D), (T1_2, D), ...], [(T2_1, D), (T2_2, D), ...], ...]
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens_list = [
            np.fromiter((xx.shape[0] for xx in xs_list[idx]), dtype=np.int64)
            for idx in range(self.num_encs)
        ]

        # subsample frame
        xs_list = [
            [xx[:: self.subsample_list[idx][0], :] for xx in xs_list[idx]]
            for idx in range(self.num_encs)
        ]

        xs_list = [
            [to_device(self, to_torch_tensor(xx).float()) for xx in xs_list[idx]]
            for idx in range(self.num_encs)
        ]
        xs_pad_list = [pad_list(xs_list[idx], 0.0) for idx in range(self.num_encs)]

        # 1. Encoder
        hs_pad_list, hlens_list = [], []
        for idx in range(self.num_encs):
            hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx])
            hs_pad_list.append(hs_pad)
            hlens_list.append(hlens)

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            if self.share_ctc:
                lpz_list = [
                    self.ctc[0].log_softmax(hs_pad_list[idx])
                    for idx in range(self.num_encs)
                ]
            else:
                lpz_list = [
                    self.ctc[idx].log_softmax(hs_pad_list[idx])
                    for idx in range(self.num_encs)
                ]
            normalize_score = False
        else:
            lpz_list = None
            normalize_score = True

        # 2. Decoder
        hlens_list = [
            torch.tensor(list(map(int, hlens_list[idx])))
            for idx in range(self.num_encs)
        ]  # make sure hlens is tensor
        y = self.dec.recognize_beam_batch(
            hs_pad_list,
            hlens_list,
            lpz_list,
            recog_args,
            char_list,
            rnnlm,
            normalize_score=normalize_score,
        )

        if prev:
            self.train()
        return y

    def calculate_all_attentions(self, xs_pad_list, ilens_list, ys_pad):
        """E2E attention calculation.

        :param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences
            [(B, Tmax_1, idim), (B, Tmax_2, idim),..]
        :param List ilens_list: list of batch (torch.Tensor) of lengths of input sequences
            [(B), (B), ..]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) multi-encoder case =>
                [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)]
            3) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray or list
        """
        self.eval()
        with torch.no_grad():
            # 1. Encoder
            if self.replace_sos:
                tgt_lang_ids = ys_pad[:, 0:1]
                ys_pad = ys_pad[:, 1:]  # remove target language ID in the beginning
            else:
                tgt_lang_ids = None

            hs_pad_list, hlens_list = [], []
            for idx in range(self.num_encs):
                hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx])
                hs_pad_list.append(hs_pad)
                hlens_list.append(hlens)

            # 2. Decoder
            att_ws = self.dec.calculate_all_attentions(
                hs_pad_list, hlens_list, ys_pad, lang_ids=tgt_lang_ids
            )
        self.train()
        return att_ws

    def calculate_all_ctc_probs(self, xs_pad_list, ilens_list, ys_pad):
        """E2E CTC probability calculation.

        :param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences
            [(B, Tmax_1, idim), (B, Tmax_2, idim),..]
        :param List ilens_list: list of batch (torch.Tensor) of lengths of input sequences
            [(B), (B), ..]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :return: CTC probability (B, Tmax, vocab)
        :rtype: float ndarray or list
        """
        probs_list = [None]
        if self.mtlalpha == 0:
            return probs_list

        self.eval()
        probs_list = []
        with torch.no_grad():
            # 1. Encoder
            for idx in range(self.num_encs):
                hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx])

                # 2. CTC loss
                ctc_idx = 0 if self.share_ctc else idx
                probs = self.ctc[ctc_idx].softmax(hs_pad).cpu().numpy()
                probs_list.append(probs)
        self.train()
        return probs_list

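# Hedged sketch of inspecting the per-encoder CTC posteriors returned by
# E2E.calculate_all_ctc_probs (illustrative; "model" and the batch variables are
# assumptions).  When mtlalpha > 0, each entry is an ndarray of shape
# (B, Tmax_i, odim); with mtlalpha == 0 the method returns [None].
#
#   probs_list = model.calculate_all_ctc_probs(xs_pad_list, ilens_list, ys_pad)
#   for i, probs in enumerate(probs_list):
#       print("encoder {}: CTC posterior shape {}".format(i + 1, probs.shape))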