博士以前

人間です

Julia でLDAの周辺化ギブスサンプリングを実装した

トピックモデルとLDA

文書のクラスタリング手法として Latent Dirichlet Allocation (LDA) というものがある。クラスタリングというのは何らかの基準に応じてデータを(教師なしで)分類することで、LDAはトピックモデルという、各文書の潜在トピックに応じてクラスタリングを行うモデルの一つである (Blei et. al., 2003)。潜在トピックというのは、例えばニュース記事なら「政治」や「経済」「スポーツ」というものを想像すればよい。それぞれのトピックに応じて出やすい単語があるはずだし、またトピックに関係なく出やすい単語(日本語なら助詞、英語なら I, we や前置詞など)もあるはずである。また文書のクラスタリングとはいったものの、別に単語の意味に頼って解析をするわけではないので、離散的なデータになら広く応用できるのも利点の一つである。

今回は SparseLDA (Yao et. al., 2009) に基づいて、LDAの効率的なギブスサンプリングのアルゴリズムを Julia 言語を使って実装してみたので、紹介しようと思う。

Github リポジトリはこちら

github.com

LDAとは

LDAは一つの文書が複数のトピックからなっているとする混合モデルである。 ある文書のある箇所にどんな単語が入るかは次のように決まる:

  1. その文書に固有のトピック分布からその箇所のトピックを確率的に抽出する
  2. そのトピックにおける単語分布から確率的に単語を抽出する

これは二種類の(一般には)いびつな形の多面サイコロがあって、まず一つ目のサイコロを振るとトピックが決まる、次にそのトピックに応じたサイコロを選び、それを振ることで単語が決まる、これを繰り返すことで文書が出来上がるモデルだと考えれば良い。 さらにLDAはこの「サイコロ」自体もある確率分布から生成されたとモデルする。

数学的に言えば、トピックと単語はそれぞれあるパラメータを持つ多項分布によって生成され、その多項分布のパラメーター自体はさらに別の確率分布(ディリクレ分布)から生成される、ということである。 詳しい説明は以前紹介した教科書を読んでもらえると良いと思うし、いろんな人がブログで解説をしたりもしている。

predoc.hatenablog.com

通常我々が持っているのは実際の文書データだけで、トピックが何であるかはわからない(だから潜在, latent という)。 LDAを使って解析するというのは、与えられた文書の単語データから潜在するトピックを推定するということである。これはベイズの定理を使って、likelihood  p(w|t) ( t はトピックで  w は単語) から事後分布  p(t|w) を計算することに対応する。

周辺化ギブスサンプリング

LDAの事後分布は(多くのベイズ学習でそうであるように)解析的には計算できないため、色々な近似推論が行われる。 ここでは周辺化ギブスサンプリング (Collapsed Gibbs Sampling, CGS) というマルコフ連鎖モンテカルロ法(MCMC)を使った学習を紹介する。 モンテカルロ法は実際に分布から何回も何回もサンプリングをして、そのヒストグラムによって確率分布を近似しようという手法である。

LDAの事後分布は各文書の各単語に付随するトピックという膨大な数の変数を持っているので、通常のモンテカルロ法のように変数全てを一度にサンプルするのは現実的ではない。 代わりにCGSではどれか一つの変数に着目して、それ以外の変数は既知であるとして一変数についてだけのサンプリングを繰り返す。 LDAでは興味のない潜在変数を全て積分した後で(これを周辺化という)、単語  w が観測された時にそれがトピック  z=t を持つ確率は


\displaystyle
 p(z=t|w) \propto \frac{\alpha_t + n_{t|d}}{\sum_t \alpha_t + n_d} \frac{\beta +n_{w|t}}{\beta V + n_t} \equiv U(t)

となる。ここで

  •  \alpha_t,  \beta :ディリクレ事前分布のパラメーター
  • n_{t|d} :文書  d 中のトピック  t の数
  •  n_{w|t} :トピック  t 中の単語  w の数
  •  n_t :トピック  t 中の全単語数
  •  n_d :文書  d 中の全単語数

この式を見るとLDAが何をやっているかをなんとなく掴むことができる。大雑把に、積の第一項が文書 d でトピック t が選ばれる確率を、第二項がトピック  t で単語  w が選ばれる確率を表しているだろう。 もし  \alpha_t = \beta = 0 ならそれぞれの確率は単に実際に観測された割合、 n_{t|d}/n_d n_{w|t}/n_t になる。 これはつまり最尤推定をしているということで、本当の確率が小さく、またデータ数が少なかったために偶然 n_{t|d}=0 n_{w|t}=0 になった場合にそれを過剰にフィットして、ゼロを返してしまう。 ディリクレ分布のパラメーター  \alpha_t \beta はそれを smoothing して過学習を防ぐ役割をしていると考えられる。

実際のサンプリングでは  Z=\sum_{t=1}^{t=T} U(t) を計算して、乱数  u 0-Z で生成し、 \sum_{t=1}^{t-1}U(t) \lt u \lt \sum_{t}^{T}U(t) となる  t を探す、などの操作をすると思う。 この計算は我々が事前に固定するパラメーターである全トピック数  T に対して線形で、訓練用の全ての文書の全ての単語について毎回行われるため、トピック数が増えると計算量がどんどん増えてしまうことが想像できるだろう。

SparseLDA

SparseLDAというのは Yao 達が提案した、サンプリングを効率的に行うアルゴリズムで、特にトピック数に対して線形に計算が増えないという点で優れている。 まず着目するのは上のサンプリング式で、それを以下のように分解する:


\displaystyle
p(z=t|w) \propto  \frac{\alpha_t\beta}{\beta V + n_t} +n_{t|d} \frac{\beta}{\beta V + n_t} + n_{w|t} \frac{\alpha_t + n_{t|d}}{\beta V + n_t}

単に分子を展開しただけである。*1 サンプリングのためにこれのトピックに関する和を取る:


\displaystyle
Z= \sum_{t} \frac{\alpha_t\beta}{\beta V + n_t} +\sum_{t}n_{t|d} \frac{\beta}{\beta V + n_t} + \sum_{t}n_{w|t} \frac{\alpha_t + n_{t|d}}{\beta V + n_t}

この形にすると以下のような戦略を取れば良さそうだということがわかる:

  • 第一項は文書にも単語にも依存しない量なので、一回のバッチ学習を通して同じ値を使うことができる。*2
  • 第二項の和の中の  n_{t|d} の係数も文書や単語に依存しないので同様に同じ値を使い続けられる。また和自体は一つの文章で共通して使うことができる。
  •  n_{t|d} は実際の学習ではかなり sparse になるベクトルなので、ノンゼロになる成分だけを保持しておけば  t の和は比較的少ない計算で済む。
  • 第三項の和の中の  n_{w|t} の係数は文書には依存するが単語には依存しないので、一つの文書を通じて同じ量を使い続けることができる。また  n_{t|d} は sparse な量なので、 n_{t|d} = 0 とした量をバッチ学習の初めに作っておき、 n_{t|d} がノンゼロになる部分だけを文書ごとに更新すればさらに効率化できる。
  • また  n_{w|t} も非常に sparse なベクトルなので、これもノンゼロになる成分だけを保持しておくようにする。

要は、単語トピックを更新する際に、 t について 毎回1 から順々に足していくのは非常に非効率なので、キャッシュできる量はしておく、sparse なベクトルについてはノンゼロ成分だけで和を取るようにしましょうということ。

 Z を計算したら、次に乱数を生成してどのトピックになるか調べることになる。実際の学習では上の式の第三項の重みが他二つに比べて桁で大きいのでまずはそこから見るようにする。またその際に n_{w|t} を降順にソートしておくことで高速化を図ることもでき、Yao 達の論文では、そのために  n_{w|t} をbitwise で持つようにしようという手法も提案している。*3 これ以上の詳細は論文を読んでほしい。

感想

夏のインターンで取り組んだので、どこかにまとめておこうと思いながら冬になってしまいました。*4なんとか年内に終わってよかったです。 LDAは実装も手軽で、典型的なベイズ学習を勉強するのに適していると思います。 また SparseLDA のように単に式を展開するだけで高速な計算ができるようになるというのも非常に勉強になりました。 次はディリクレ過程を勉強したい。

*1: \sum_t \alpha_t + n_d t に無関係なので無視

*2:同じ値と言っても、ギブスサンプリングなので今更新しようとしている単語の寄与は、毎回引いておく必要がある。この計算量はトピック数が増えても変わらない。以後「同じ値」とか「共通して使える」とか言った時は同様の意味である。

*3:実装が悪いのか、そうしても僕のコードでは siginificant な improvement は無かったが......

*4:多くの研究論文がそうであるように、このノートが念頭に置いているメインの読者は将来の自分である。人は忘れる生き物なので......