2.2 変分オートエンコーダ

オートエンコーダの入力は画像であったが、変分オートエンコーダの入力は確率変数である。 変分オートエンコーダは生成モデルである。

変分オートエンコーダの概要

変分オートエンコーダ (Variational Autoencoder, VAE) は生成モデルの1つで、 モデル分布 $p_{\theta} (x)$ の尤度が最大となるパラメータを計算する。 VAE は尤度を観測変数 $x$ と潜在変数 $z$ の2つに分解して計算する。 モデル分布 $p_{\theta}(x)$ を$p_{\theta}(z)$ と $p_{\theta}(x | z)$ の2つの確率分布に分解でき、潜在変数 $z$ の積分で計算できる。

$\displaystyle p_{\theta}(x) = \int p_{\theta} (x | z) ~ p_{\theta} (z) $

オートエンコーダは、画像 $x$ の特徴量をエンコーダで隠れ層の潜在変数 $z$ に圧縮し、デコーダで潜在変数を画像 $x'$ に復元した。 潜在変数 $z$ の分布は不明であるため、どのような変数をデコーダに入力するとどのような画像が生成できるかわからない。そのため、オートエンコードは生成モデルとはいえない。

変分オートエンコーダ (VAE) では、潜在変数 $z$ を標準正規分布にしたがう確率変数でモデル化している。このため、学習に使用する入力画像の特徴量を標準正規分布に押し込むことができ、VAEに標準正規分布を入力すると目的の画像を生成できる。 学習が終了するとエンコーダは不要となり、デコーダは標準正規分布に従う確率変数から 画像 $x'$ を確率的に生成できる。そのため VAE は生成モデルであるといえる。

VAE の潜在変数は標準正規分布n従うようにモデル化されているため、さまざまな画像に対する潜在変数が潜在空間内で密集している。 2枚の画像をそれぞれ潜在変数に変換し、その潜在変数を補間した変数を求めることで 2枚の画像を連続的に変化させることができる。

変分オートエンコーダのネットワーク

変分オートエンコーダはエンコーダとデコーダで構成される。 MNIST の数字画像を用いるので、入力画像は 28x28x1=784 次元ベクトルとなる。 エンコーダは10次元の平均ベクトルと分散ベクトルを出力し、これの線形和から10次元の潜在変数ベクトル $z$ を求める。$z$は標準正規分布にしたがう。

VAE クラスは Encoder クラスと Decoder クラスで構成される。 forward 関数の中で、Encoder は平均 meanと分散 var を出力し、 この2つを引数にして reparameterize メソッドで潜在変数 $z$ を計算する。 潜在変数 $z$ は、平均と分散をもとに標準正規分布にしたがう乱数に変換したものである。 最後に、Decoder クラスに潜在変数 $z$ を渡して、生成画像 $y$ を得る。

データセットの作成

[自分へのメモ] MNIST のデータをダウンロードするきに 503 エラーがおきるようになったので注意。 そのため、自分でMNISTデータをダウンロードして展開しておく必要がある。(2.1章 からの転載)

[自分へのメモ] 2.1章が実行済みならば既に以下のフォルダにファイルが展開されている。

./data/mnist/processed/{training, test}.pt
            /raw/t*-{images,labels}-idx*-ubyte

ネットワークの定義

torch.randn(*size) ... 平均0, 分散1の一様分布からの乱数で満たされたTensor を返す。

変分オートエンコーダの損失関数

変分オートエンコーダでは、 潜在変数が標準正規分布に従うように 正規化の誤差を損失関数に含める必要がある。 そのため、VAEの損失関数は 再構成誤差の損失関数 $J^{REC}$ と 正則化の損失関数 $J^{REG}$ の和となる。

$J = J^{REC} + J^{REG}$

$\displaystyle J^{REC} = - \frac{1}{N} \sum_{i=1}^{N} (x_i \log y_i + (1 - x_i) \log (1 - y_i))$

$\displaystyle J^{REG} = - \frac{1}{2} \sum_{j=1}{J} (1 + \log {\sigma}_j^2 - \mu_j^2 - {\sigma}_j^2$

$J^{REG}$ は $\sigma = 1$ かつ $\mu = 0$ のときに最小となる。 つまり、標準正規分布にしたがうと損失が0に近づく。

[自分へのメモ] torch.nn.functions.softplus

$\displaystyle \mbox{Softplus}(x) = \frac{1}{\beta} * \log (1 + e^{\beta * x})$

SoftPlus 関数は ReLU 関数をなめらかに近似した関数である。 出力は必ず正となる。

数値的な安定性のために $input \times \beta > threashold$ の場合は線形関数となる。 デフォルト値は $\beta =1$, $threshold = 20$ である。

学習の実行

画像の生成

squeeze() は、テンソル中のサイズが1の次元を削除する。 detach()は、現在のグラフから切り離された新しいテンソルを生成して返す。

元画像と復元画像