FreeStyleWiki

抽象型(Abstractive)要約の機械学習モデル詳細

このエントリーをはてなブックマークに追加

[機械学習,文章要約,論文読みメモ]

抽象型(Abstractive)要約の機械学習モデル詳細

  全体の概要

学習処理の流れ

  1. INPUTの文章(d1,d2,d3...)
  2. Extractor(抽出された文章(d1,d2,d3...)を作成) class PtrExtractSumm(nn.Module)
  3. Abstractor(要約された文章(s1,s2,s3...)を作成) class CopySumm(Seq2SeqSumm)
  • 学習はすべて一気通貫で実施するのではなく、Extractor、Abstractor、それらを統合する強化学習の順で実施する

復号処理(学習済みモデルの実行)の流れ

beam=1 の場合

  1. INPUTの文章(d1,d2,d3...)
  2. Extractor(抽出された文章(d1,d2,d3...)を作成) class RLExtractor(object)
  3. Abstractor(要約された文章(s1,s2,s3...)を作成) class Abstractor(object)

前提知識

実装レベルの詳細,数式

  • 以下、概要をソースコードレベルにブレークダウンする
    • 注意:機械学習のコードなので機械学習のモデルのみが解説される、処理自体は右から左にデータが流れるだけだからだ

  Extractor

学習の設定

+ train_extractor_ml.py

  • Extractorの実装は2つあった(どちらもpytorchのtorch.nn.Module)を継承している
    • ExtractSumm ... ff: 順伝播型ニューラルネットワーク
    • PtrExtractSumm ... rnn: 回帰型ニューラルネットワーク
  • PyTorchでMNISTする が参考になる
    • 要は、torch.nn.Moduleを継承してforwardをオーバーライドしておけば、forwardに定義した通り学習データの学習ができる
    • pytorch側の構造を見てみると、典型的なTemplate Methodパターンで実装されている
    • Pythonでcallを実装すると、クラス(引数)のような呼び出しで処理が実行できる、それをTemplate Methodパターンの呼び出し実装で使っている

+ PyTorchによる学習部分の実装

  • Extractor部分の流れ
    def forward(self, article_sents, sent_nums, target):

        # 畳み込みエンコーダーがそれぞれの文節rjとして処理
        # RNNエンコーダーが隠し層hjを処理する
        enc_out = self._encode(article_sents, sent_nums)
        bs, nt = target.size()
        d = enc_out.size(2)
        ptr_in = torch.gather(
            enc_out, dim=1, index=target.unsqueeze(2).expand(bs, nt, d)
        )

        # RNNデコーダーが隠し層jtをタイムステップtで処理する
        output = self._extractor(enc_out, sent_nums, ptr_in)
        return output

畳み込みエンコーダーがそれぞれの文節rjとして処理

PyTorchでLSTMをする際、食わせるインプットデータは3次元のテンソルある必要があります。具体的には、文章の長さ × バッチサイズ × ベクトル次元数 となっています。

Convolutional word-level sentence encoder

w/ max-over-time pooling, [3, 4, 5] kernel sizes, ReLU activation

 単語レベル、文の畳み込みエンコーダー
 pooling層の設定は"max-over-time"、カーネルサイズ [3, 4, 5]、活性関数はReLU
 // nn.Moduleを継承し、forwardを実装する
 // 初期化は以下のようにされている
 ConvSentEncoder(
   vocab_size,    # 語彙のサイズ
   emb_dim,       # word2vecを使って作られたベクトルの次元
   conv_hidden,   # 隠し層のサイズ
   dropout        # dropoutさせる確率、デフォルト=0.0なので使われていない
 )
class ConvSentEncoder(nn.Module):

    def __init__(self, vocab_size, emb_dim, n_hidden, dropout):
        super().__init__()

        # 畳み込みニューラルネットワークを作る
        # https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html
        #
        # 入力チャンネル:word2vecを使って作られたベクトルの次元
        # 出力チャンネル数:隠し層のサイズ
        # カーネルサイズ:[3, 4, 5]
        self._convs = nn.ModuleList([nn.Conv1d(emb_dim, n_hidden, i) for i in range(3, 6)])
        ...


    def forward(self, input_):
        # input_ は "文章の長さ × バッチサイズ × ベクトル次元数" のテンソルになっているはず
        emb_input = self._embedding(input_)
        # dropoutは学習の際の過学習を防ぐために訓練データをランダムで0を混ぜる設定
        conv_in = F.dropout(emb_input.transpose(1, 2), self._dropout, training=self.training)

        # torch.cat: 指定された次元でデータを結合(dim=1なので配列になる)
        # 3層の畳み込みNNにデータを入力して活性関数ReLUで取り出してそのmax値を取り出して結合
        # 
        output = torch.cat([F.relu(conv(conv_in)).max(dim=2)[0] for conv in self._convs], dim=1)
        return output

RNNエンコーダーが隠し層hjを処理する

ドキュメントのj番目の文のhjとして示される、同じドキュメント内の過去および将来のすべての文のコンテキストを考慮した強力な表現を学習できます

とあるが、説明がほとんどない。よくよくコードを読んでみるとLSTMを使って分類しているだけのようだ。以下がポイント

  • 双方向LSTMを使う
  • many to manyのタスクなので、LSTMの第1戻り値を使用する
class LSTMEncoder(nn.Module):

    def __init__(self, input_dim, n_hidden, n_layer, dropout, bidirectional):
        super().__init__()
        ...
        # input_dim: 各時刻における入力ベクトルのサイズ
        # n_hidden: LSTMの隠れ層ベクトルのサイズ
        # dropoutは使われてない、bidirectionalは双方向LSTMかどうかを定義
        # ここの定義はLSTMの隠れ層を定義しているようだ
        self._lstm = nn.LSTM(input_dim, n_hidden, n_layer, dropout=dropout, bidirectional=bidirectional)


    // 順伝播のコード
    def forward(self, input_, in_lens=None):
        """ [batch_size, max_num_sent, input_dim] Tensor"""

        size = (self._init_h.size(0), input_.size(0), self._init_h.size(1))

        init_states = (self._init_h.unsqueeze(1).expand(*size), self._init_c.unsqueeze(1).expand(*size))

        # many to manyのタスクなので、第1戻り値を使っている
        # torch.nn.LSTMのoutputはoutput,(h_n, c_n) = torch.nn.LSTMという形式
        lstm_out, _ = lstm_encoder(input_, self._lstm, in_lens, init_states)

        return lstm_out.transpose(0, 1)

LSTMの結果を返す前に中のデータをいじっているように見える。それは以下で

functional LSTM encoder (sequence is [b, t]/[b, t, d], lstm should be rolled lstm)

 関数的なLSTMエンコーダー(シーケンスは[b、t] / [b、t、d]、 LSTMは展開されたLSTMである必要があります)
  • LSTMの単語以外何もわからねえ
    • [b、t] / [b、t、d]
    • 展開されたLSTM
  • コードは適宜整形している
// model/rnn.py#L9
def lstm_encoder(sequence, lstm, seq_lens=None, init_states=None, embedding=None):

    """ functional LSTM encoder (sequence is [b, t]/[b, t, d], lstm should be rolled lstm)"""

    batch_size = sequence.size(0)

    // LSTMに`batch_first=True`が設定されていると、LSTMの引数となるテンソルは「バッチサイズ × 文章の長さ × ベクトル次元数」となる
    // なのでここで転置をかけている、まあ不要
    if not lstm.batch_first:
        ...

    // こちらも使われていなそう
    if seq_lens:
        ...

    if init_states is None:
        // ここも、どうやらLSTMの初期状態が引数で渡されないことを想定している。不要。
        device = sequence.device
        init_states = init_lstm_states(lstm, batch_size, device)
    else:
        init_states = (init_states[0].contiguous(), init_states[1].contiguous())

    if seq_lens:
        ...
    else:
        // 実際にLSTMのモデルにデータを渡す
        // また nn.Moduleの __call__を使った呼び出し、forwardを実行
        lstm_out, final_states = lstm(emb_sequence, init_states)


    // torch.nn.LSTMのoutputはoutput,(h_n, c_n)
    // 結局呼び出し元で final_statuesは捨てられている
    return lstm_out, final_states

RNNデコーダーが隠し層jtをタイムステップtで処理する

  • RNNデコーダーが隠し層jtをタイムステップtで処理する: 実装(LSTMPointerNet)
  • Pointer Networks を使う
    • 元論文のサマリー 2015: Pointer Networks #19
    • インプットが可変のデータを処理する際に、組合せ爆発を起こさずうまいこと固定長で出力させるための仕組みっぽい。要約文は原文に対して固定の文字列で返したいからか。
    • このWiki内でのまとめは → Pointer network

  Abstractor

  • 読んでいる限りここが一番重要なステップになりそう
    def forward(self, article, art_lens, abstract, extend_art, extend_vsize):
        attention, init_dec_states = self.encode(article, art_lens)
        mask = len_mask(art_lens, attention.device).unsqueeze(-2)
        logit = self._decoder(
            (attention, mask, extend_art, extend_vsize),
            init_dec_states, abstract
        )
        return logit

RNNエンコーダーが隠し層hjを処理する

RNNデコーダーが隠し層stを処理する

  強化学習

  復号処理(学習済みモデルの実行)