一言メモ
- GANを用いたテキスト生成モデル。GeneratorはLSTM、DiscriminatorはTextCNN。
- FeatureMatchingを用いて、Gで生成した文集合を本物の文集合に近づけることで文法上正しい文の生成に成功。
- ArgmaxをSoftmaxの極限で近似するのは面白い工夫。
書誌情報
- リンク
- 著者:Yizhe Zhang, Zhe Gan, Kai Fan, Zhi Chen, Ricardo Henao, Lawrence Carin
- NIPS2016 ✕ 3, ICML ✕ 2を通したデューク大学の化物PhD
- ICML2017(arXiv on 12 Jun 2017)
- NIPS2016 Workshopの進化版
概要
- 文章生成にGANを用いるTextGANを提案した。GeneratorはLSTM、DiscriminatorはTextCNNを利用
- FeatureMatchingと再構成の項を目的関数に追加することで、GANの問題とされているModeCollapse・勾配消失問題を軽減している
- いろいろと学習テクニックを導入しており、中でもSoft-argmax近似は面白い工夫
- Pre-trainingもしっかりしており、SeqGANよりいい定量的評価を実現し、定性的には現実的な文生成に成功している
背景
自然言語生成の系譜
- 文章を生成するタスクは、主に訓練データから確率分布を評価し、その分布からサンプリングする
- 先行研究としては、2014年にRNNベースのAutoEncoder、 2016年にRNNベースのVAEなどがある
- しかし、RNNベースではうまくいかない
— 潜在空間の一部しかカバーできない
— Exposure Bias:RNNでは生成時に一個前に出力を使うため、徐々にズレが蓄積し、文後半ではうまく作れない -
本物らしく作るGenerator vs 偽物を見抜くDiscriminator という2つのモデルを競わせるモデル
— Dは下記式を最大化、Gは最小化するように動く
-
GANの問題点
— ModeCollapsing:潜在変数から同じような結果しか作らなくなる
— Dが局所解に近づいた場合、勾配消失が起きる
モデル:TextGAN
- GはLSTM、DとEはTextCNNを使用
- Feature Machingを採用 [Salimans et al. (2016)]
— TextCNNで本物と偽物の文章群から抽出した特徴ベクトルfのMMDを近づける
— MMD:Gaussianカーネルで再生核ヒルベルト空間(RKHS)へ写像し、 平均の差を用いて一致度を測定 [Gretton et al (2012)] - Reconstruction Errorの追加
— 特徴ベクトルfから元のノイズzを再現する項
損失関数
- 下記式のL_DをDは最大化、 L_GをGは最小化するように学習させる
Discriminator/Encoder
- CNNでSotAを出したTextCNNを採用(元論文)
- 文を学習済みの埋め込み行列でk ✕ Tの行列に変換
- Windowサイズが異なるConvolutionのフィルタをかけ、フィルタ毎にMaxPooling(活性化関数はtanh)
- DはMLPの後にSoftmaxで真偽を判定、EはMLPでzを復元
- Pytorch実装をみるとわかりやすいかも
Generator
- LSTMで潜在ベクトルzからテキストyを作成する
- 一点だけ特殊なのが、zを毎回LSTMにconcatして渡す必要がある点
学習テクニック
本研究では、本物らしい文をGANで作るために様々な工夫を凝らしている
1. Soft-argmax approximation
- テキストの生成では、離散変数を含み、途中でargmaxを挟むため勾配評価が難しかった
- 本研究では下記式で、argmaxをsoftmaxの極限で近似する(L → ∞)
- 実装時はL=1000くらい
2. Pretraining
- G(LSTM)
— CNN-LSTM autoencoderを利用 [Gan et al. (2016)] - D/E(CNN): Permutation training
— テキストの2単語を入れ替えて偽の文を作り学習
— 単語追加・消去より難しいタスク
3. Soft-labeling
- 1 or 0とするのが普通であるが、正解=0.7-1.2、偽=0-0.3からランダムにサンプルする[Salimans et al (2016)]
- 本論文では、最大0.99, 最低0.01としている
4. Compressing Network
- 課題
— GaussianカーネルMMDでは特徴ベクトルfの次元に応じて、 ミニバッチのサイズを大きくする必要がある - 対策:Compressing Network
— 特徴ベクトルfを圧縮するための全結合レイヤーを追加
— 変換後の次元数はデータ効率と表現力のトレードオフ
— 実装では900次元のfを全結合で200次元まで落としている
5. Gaussian covariance matching
- データ効率のためにMMDの式を変更
- カーネルトリックの代わりに共分散と平均を用いて行う
結果
定量的比較
- VAEやseqGANよりいいBLEUスコア
- ただ、今回の生成タスクは翻訳やタスク志向がないのでBLEUでいいのかは微妙
生成文
- 文法的に正しそうな文章が作られるようになった
- 特に括弧やクォーテーションなどがうまくいっている
- 意味も短い文だと正しいが、20単語以上になると少しおかしくなる
- また、Dの学習もよくできており、最終的に偽物の検知精度は95%に達している
潜在特徴空間の起動
- 文Aから文Bまで潜在変数を連続的に変更した際の変化
- AEより意味的にも文法的にも正しいが、完全な連続性ではなく、一部大きな変化が起きる
発表スライド
東大のDL輪読会で発表したスライドです。よろしければ参考にしてください!