FreeStyleWiki

Pointer network

このエントリーをはてなブックマークに追加

[機械学習,数学,文章要約]

Pointer network

論文の概要

    1. 入力要素へのポインタとなる出力の条件付き確率を予測させる学習を行う
    2. これらの問題では出力の辞書が入力の長さに依存しているというような 組合せ最適化 問題に属している
    3. この方式では入力のシーケンスを出力のメンバーとして選び出すためのポインタとして"attention"を用いる
    4. この方式はseq2seqの精度を入力attentionで向上させるだけでなく、出力の辞書サイズを可変にできることを可能にする

  問題設定

例えば凸包(Convex Hull)を求める。

有限集合の凸包は輪ゴムを掛けるようなものである

とあるように、問題自体は簡単で全部の点を囲める、点の集合を求めることだ。

このとき入力が10個の点なので、集合Pとして与えられたら

P+%3D+%5Cleft%5C%7B+P%5F1%2C+%2E%2E%2E%2C+P%5F%7B10%7D+%5Cright%5C%7D+

出力は6個の点なので、各出力は入力の集合上のインデックスとすると

C%5EP+%3D+%5Cleft%5C%7B+2%2C+3%2C+5%2C+6%2C+7%2C+8+%5Cright%5C%7D+

みたいな感じになる(インデックス値は適当です)

  • これをsoftmax(やほかの活性関数)で出力しようと思うのだがsoftmaxは固定長の次元しか持てないのでできない
    • これが概要で書いていた出力の辞書が入力の長さに依存しているという問題

  既存の手法(seq2seq)

seq2seqの式の説明

  • 論文やYouTubeではseq2seqを数式で説明している(というのはこれがパラメトリックモデルだから)
  • パラメトリックモデルは3つのステップでゴールとなるモデル(=数式)を決定する(そしてパラメトリックモデルは確率分布を用いたモデルになる)
    1. パラメーターを含むモデル(数式)を設定する
    2. パラメーターを評価する基準を定める
    3. 最良の評価を与えるパラメーターを決定する

トレーニングデータのペア %28%5Cmathcal+P%2C+%5Cmathcal+C%5E%7B%5Cmathcal+P%7D%29+ が与えられたとき次の条件付き確率をRNN (LSTM)によるパラメトリックモデルで推定するというもの。

パラメーターを含むモデル(数式)は以下になる

p%28%5Cmathcal+C%5E%7B%5Cmathcal+P%7D+%7C+%5Cmathcal+P+%3B+%5Ctheta%29+%3D+%5Cprod%5F%7Bi%3D1%7D%5E%7Bm%28%5Cmathcal+P%29%7D+p%5F%7B%5Ctheta%7D+%28C%5Fi+%7C+C%5F1%2C+%2E%2E%2E%2C+C%5F%7Bi%2D1%7D%2C+%5Cmathcal+P%3B+%5Ctheta%29+

  • 凸包問題なら、全体の座標の集合Pと、答えになる座標の集合C^Pをデータセットとして与えまくって学習させるというわけか

数式が若干わかりにくい

  • p%28%5Cmathcal+C%5E%7B%5Cmathcal+P%7D+%7C+%5Cmathcal+P+%3B+%5Ctheta%29+

条件付き確率 P(A|B) はしばしば「B が起こったときの A の(条件付き)確率」「条件 B の下での A の確率」などと表現される

なので

    • 条件「全体の座標の集合P」の下での「答えになる座標の集合C^P」ぐらいの意味で
    • θは最良の評価を与えるパラメーターであり、学習によって最適化される
  • %5Cprod%5F%7Bi%3D1%7D%5E%7Bm%28%5Cmathcal+P%29%7D+p%5F%7B%5Ctheta%7D+%28C%5Fi+%7C+C%5F1%2C+%2E%2E%2E%2C+C%5F%7Bi%2D1%7D%2C+%5Cmathcal+P%3B+%5Ctheta%29+
    • Пが総乗記号
      • m(P) は出力対象の個数です、忖度してください
      • i=1~m(P)までpθ(A|B)の条件付き確率を掛け算して合計する

と、ここまで読んで思ったが総乗記号で全部の積を求めてしまうと左辺がスカラー値になって実用的じゃないと思った。元のseq2seqの論文では左辺も条件付き確率の配列になっているのでそれが正しいんじゃないかなあ…( ^ω^)。左辺は暗黙的に行列だと言われてしまえばおしまいだが、プログラムならコンパイルエラーになりそうな数式だ。しかしこれは既存手法を紹介しているだけなのでほっとく。

gtexのコマンドで式が出なくなってしまったので最良の評価を与えるパラメーターは省略、参照しているブログや元の論文を参照ねがう。

  既存の手法(Content Based Input Attention)

Pointer Networks より引用

attentionというものを考えて、seq2seqでは固定的であったdecoderのステートに対してより多くの情報を付加する.

%28e%5F1%2C+%2E%2E%2E%2C+e%5Fn%29 LSTM encoderの隠れ状態

%28d%5F1%2C+%2E%2E%2E%2C+d%5F%7Bm%28%5Cmathcal+P%29%7D%29+ LSTM decoderの隠れ状態

としたとき、attentionを以下のように定義する.

  • u%5Fj%5Ei++%3D+v%5ET+%5Ctanh%28W%5F1e%5Fj+%2B+W%5F2d%5Fi%29+j+%5Cin+%281%2C+%2E%2E%2E%2C+n%29+
  • a%5Fj%5Ei++%3D+%7B%5Crm+softmax%7D%28u%5Fj%5Ei%29+j+%5Cin+%281%2C+%2E%2E%2E%2C+n%29+
  • d%5Fi%5E%7B%27%7D++%3D++%5Csum%5F%7Bj%3D1%7D%5En+a%5Fj%5Ei+e%5Fj++
  • %7B%5Crm+hidden+%5C+states%7D+%3D+%7B%5Crm+concat%7D%28d%5Fi%5E%7B%27%7D%2C+d%5Fi%29+

この手法も出力が固定長という問題がある。動画ではさらっと流されているのでここまで。

  Ptr-Net

やっとPtr-Netの登場、先ほどの式のSoftmaxの左辺が条件付き確率になった

  • パラメーターを含むモデル(数式)
    • %28e%5F1%2C+%2E%2E%2E%2C+e%5Fn%29 LSTM encoderの隠れ状態
    • %28d%5F1%2C+%2E%2E%2E%2C+d%5F%7Bm%28%5Cmathcal+P%29%7D%29+ LSTM decoderの隠れ状態
    • u%5Fj%5Ei+%3D+v%5ET+%5Ctanh%28W%5F1e%5Fj+%2B+W%5F2d%5Fi%29+%5C+j+%5Cin+%281%2C+%2E%2E%2E%2C+n%29+
    • p%28C%5Fi+%7C+C%5F1%2C+%2E%2E%2E%2C+C%5F%7Bi%2D1%7D%2C+%5Cmathcal+P%29+%3D+%7B%5Crm+softmax%7D%28u%5Ei%29+

パラメータはこれでわかったのだけどどのように最適化すればいいのだろう?

  • seq2seqが入力要素すべてに対して予測値を出しているのに対して、Ptr-Netは入力値へのポインタを示している