FreeStyleWiki

抽象型(Abstractive)要約の復号処理

このエントリーをはてなブックマークに追加

[機械学習,文章要約,論文読みメモ]

抽象型(Abstractive)要約の復号処理

  復号処理(学習済みモデルの実行)の概要

  • abstractorの使われ方が学習時と違う
    • beam=1 の場合
  1. INPUTの文章(d1,d2,d3...)
  2. Extractor(抽出された文章(d1,d2,d3...)を作成) class RLExtractor(object)
  3. Abstractor(要約された文章(s1,s2,s3...)を作成) class Abstractor(object)

  Abstractorのdecode部分に関連する理論について

論文の後半部分にモデルの数式についての詳細が載っている。

Sequence-Attention-Sequence Model

We use a standard encoder-aligner-decoder model (Bahdanau et al., 2015; Luong et al., 2015)

with the bilinear multiplicative attention function (Luong et al., 2015),

fatt(hi, zj ) = h > i Wattnzj , for the context vector ej .

We share the source and target embedding matrix Wemb as well as output projection matrix as in Inan et al. (2017); Press and Wolf (2017); Paulus et al. (2018).

Copy Mechanism

We add the copying mechanism as in See et al. (2017) to extend the decoder to predict over the extended vocabulary of

words in the input document. A copy probability

Seq2seqAttention Mechanismを追加したモデル

我々は標準のencoder-aligner-decoderモデルに

バイリニアAttention

バイリニアattentionを使います

\( f_{att}(h_i, z_j) = h^{T}_i W_{attn} z_j \)

我々は要約元文章と要約文で文字列の埋め込み表現ベクタ \( W_{emb} \) を共有すると同時に、出力の投射行列でもそれを共有します。

コピー機構

We add the copy mechanism to help directly copy some outof-vocabulary (OOV) words

  • OOVは自然言語処理に古くからある課題の1つで、学習データに出現しない単語(未知語)は予測時には使用することができないというものです。
  • Pointer Networksが行っていたような入力から出力を選ぶという機構を使えば、たとえ学習データの語彙に含まれていなくても入力からコピーしてくることでOOVの問題を緩和することができます

とある

  ソースコード

コールスタックは以下のようになる

CopySumm#_prepro

    def _prepro(self, raw_article_sents):
        ext_word2id = dict(self._word2id)
        ext_id2word = dict(self._id2word)
        for raw_words in raw_article_sents:
            for w in raw_words:
                if not w in ext_word2id:
                    ext_word2id[w] = len(ext_word2id)
                    ext_id2word[len(ext_id2word)] = w

        # 以下のようなデータが
        # raw_article_sents = [
        #     ["How", "do", "you", "think?"],
        #     ["I", "don't", "have", "any", "idea"],
        # ]
        # word2vecのID表現に変換される
        # raw_article_sents = [
        #     [78, 129, 999, 345],
        #     [21, 593, 903, 68, 393],
        # ]
        articles = conver2id(UNK, self._word2id, raw_article_sents)

        # articles のリストの要素数
        # art_lens = [4, 5, ...]
        art_lens = [len(art) for art in articles]

        # articlesを1文の最大長で揃えて, paddingした Tensorに変換する
        article = pad_batch_tensorize(articles, PAD, cuda=False).to(self._device)

        # 入力文字列を取り込むために同じデータを作る(ただし入力から作った語彙ext_word2idを使う)
        extend_arts = conver2id(UNK, ext_word2id, raw_article_sents)
        extend_art = pad_batch_tensorize(extend_arts, PAD, cuda=False).to(self._device)
        extend_vsize = len(ext_word2id)

        # 次の処理に渡すためtupleに変換
        dec_args = (article, art_lens, extend_art, extend_vsize, START, END, UNK, self._max_len)
        return dec_args, ext_id2word

CopySumm#batch_decode

  • article: 記事のTensor(h, w) = (行数, 1文の最大長), paddingされており、中身はword2vecのID表現
  • art_lens: 記事ごとの単語数のlist, size=h
  • extend_art: 記事のTensor(h, w) = (行数, 1文の最大長), paddingされており、中身はword2vecのID表現(入力から作った語彙を含む)
  • extend_vsize: 入力文字列を取りこんだ辞書のサイズ
  • go, eos, unk: それぞれ文章開始,文章終了,不明語を表す埋め込み表現のID
  • max_len: 出力する要約の最大文字数
    def batch_decode(self, article, art_lens, extend_art, extend_vsize, go, eos, unk, max_len):
        # 読み込んだ文章のサイズ、ここではいくつ文章があるか
        batch_size = len(art_lens)
        # 語彙サイズ
        vsize = self._embedding.num_embeddings

        # "attention"を取得
        # attention のサイズは article = (h, w) であれば
        # attention = (h, w, 256) になる
        attention, init_dec_states = self.encode(article, art_lens)

        mask = len_mask(art_lens, attention.device).unsqueeze(-2)
        attention = (attention, mask, extend_art, extend_vsize)
        tok = torch.LongTensor([go]*batch_size).to(article.device)
        outputs = []
        attns = []
        states = init_dec_states
        for i in range(max_len):
            tok, states, attn_score = self._decoder.decode_step(tok, states, attention)
            attns.append(attn_score)
            outputs.append(tok[:, 0].clone())
            tok.masked_fill_(tok >= vsize, unk)
        return outputs, attns

Seq2SeqSumm#encode

    def encode(self, article, art_lens=None):
        # article: 記事のTensor(h, w) = (行数, 1文の最大長)
        # size = (2, h, 256)
        size = (
            self._init_enc_h.size(0),
            len(art_lens) if art_lens else 1,
            self._init_enc_h.size(1)
        )
        # init_enc_states = (Tensor(2, h, 256) , Tensor(2, h, 256))
        init_enc_states = (
            self._init_enc_h.unsqueeze(1).expand(*size),
            self._init_enc_c.unsqueeze(1).expand(*size)
        )

        # enc_art      = Tensor(w, h, 512)  <-- 転置している
        # final_states = ( Tensor(2, h, 256), Tensor(2, h, 256) )
        enc_art, final_states = lstm_encoder(
            article, self._enc_lstm, art_lens, init_enc_states, self._embedding
        )

        if self._enc_lstm.bidirectional:
            h, c = final_states
            final_states = (
                torch.cat(h.chunk(2, dim=0), dim=2),
                torch.cat(c.chunk(2, dim=0), dim=2)
            )
        init_h = torch.stack([self._dec_h(s)
                              for s in final_states[0]], dim=0)
        init_c = torch.stack([self._dec_c(s)
                              for s in final_states[1]], dim=0)
        init_dec_states = (init_h, init_c)
        attention = torch.matmul(enc_art, self._attn_wm).transpose(0, 1)
        init_attn_out = self._projection(torch.cat(
            [init_h[-1], sequence_mean(attention, art_lens, dim=1)], dim=1
        ))
        return attention, (init_dec_states, init_attn_out)

AttentionalLSTMDecoder#decode_step

CopyLSTMDecoder#_step

step_attention

dot_attention_score

    # こちらは継承元(AttentionalLSTMDecoder)のdecode_stepを使用する
    def decode_step(self, tok, states, attention):
        # tok:       [2,2,2...], size=h この時点ではINPUTの文の行数サイズのlist, 2で埋められている
        # states:    (init_dec_states, init_attn_out)
        # attention: (h, w, 256)
        # logit.shape => torch.Size([128 (LSTMの隠し層のサイズ) , 30000 (辞書サイズ) + copy機能で追加された単語数])

        logit, states, score = self._step(tok, states, attention)

        # 最大値のindexだけ取得して返している
        # logit は 以下のような構成
        #
        # 横のサイズが辞書と同じ(=辞書の単語を表す)
        # 縦のサイズはLSTMの隠し層のサイズと同じ(=入力単語のインデックスを表す)
        # [0.001, 0.002, 0.001, ... 0.6]
        # [0.001, 0.002, 0.001, ... 0.6]
        # [0.001, 0.002, 0.001, ... 0.6]
        # [0.001, 0.002, 0.001, ... 0.6]
        # [0.001, 0.002, 0.9  , ... 0.6]
        #  ...
        # [0.001, 0.002, 0.001, ... 0.6]
        # torch.max で1行の中で一番確率が高い単語が取り出せる、つまり入力単語に対応する最適な単語が出せる
        out = torch.max(logit, dim=1, keepdim=True)[1]

        # ここで縦ベクトルになっている
        # 格納されているindexを単語に直すと以下のような感じだった
        # ['誘拐', '『', 'スピルバーグ', '水谷', '美女', 'キー',...] ^ T  <--  転置の意味
        # torch.maxの返り値は(values, indices) = (最大値, 最大値のindex)
        # https://pytorch.org/docs/stable/generated/torch.max.html

        return out, states, score

    # こちらはCopyLSTMDecoderの_stepを使用する
    def _step(self, tok, states, attention):
        prev_states, prev_out = states
        lstm_in = cat(
            [self._embedding(tok).squeeze(1), prev_out],
            dim=1
        )
        states = self._lstm(lstm_in, prev_states)
        lstm_out = states[0][-1]
        query = mm(lstm_out, self._attn_w)
        attention, attn_mask, extend_src, extend_vsize = attention
        context, score = step_attention(query, attention, attention, attn_mask)
        dec_out = self._projection(cat([lstm_out, context], dim=1))

        gen_prob = self._compute_gen_prob(dec_out, extend_vsize)
        # コピーごとの確率を計算する
        copy_prob = sigmoid(self._copy(context, states[0][-1], lstm_in))

        # add the copy prob to existing vocab distribution
        # 条件付き確率(logit)を返す
        lp = log(
            ((-copy_prob + 1) * gen_prob).scatter_add(
                dim=1,
                index=extend_src.expand_as(score),
                src=score * copy_prob
            ) + 1e-8)  # numerical stability for log

        return lp, (states, dec_out), score

    def step_attention(query, key, value, mem_mask=None):
        """ query[(Bs), B, D], key[B, T, D], value[B, T, D]"""
        score = dot_attention_score(key, query.unsqueeze(-2))
        if mem_mask is None:
            # 論文中で exp (eij) と書かれているところだと思う
            norm_score = softmax(score, dim=-1)
        else:
            norm_score = prob_normalize(score, mem_mask)
        output = attention_aggregate(value, norm_score)
        return output.squeeze(-2), norm_score.squeeze(-2)

    def dot_attention_score(key, query):
        """[B, Tk, D], [(Bs), B, Tq, D] -> [(Bs), B, Tq, Tk]"""
        return query.matmul(key.transpose(1, 2))


StackedLSTMCells#forward

    def forward(self, input_, state):
        """
        Arguments:
            input_: FloatTensor (batch, input_size)
            states: tuple of the H, C LSTM states
                FloatTensor (num_layers, batch, hidden_size)
        Returns:
            LSTM states
            new_h: (num_layers, batch, hidden_size)
            new_c: (num_layers, batch, hidden_size)
        """
        hs = []
        cs = []
        for i, cell in enumerate(self._cells):
            s = (state[0][i, :, :], state[1][i, :, :])
            h, c = cell(input_, s)
            hs.append(h)
            cs.append(c)
            input_ = F.dropout(h, p=self._dropout, training=self.training)

        new_h = torch.stack(hs, dim=0)
        new_c = torch.stack(cs, dim=0)

        return new_h, new_c