3.2 DCGAN

DCGAN の概要

DCGAN (Deep Convolutional Generative Adversarial Network) は、生成器と識別器に 畳み込みネットワーク(Convolutional Neural Network, CNN) を利用した GAN の発展版である。

データセットの作成

torchvision.transforms.Normalize(mean,std): Tensor画像を正規化する。PIL画像には対応していない。 すなわち (C, H, W) 形式の画像を扱う。 平均 mean と標準偏差 std はチャネルの個数分だけ与える。 mean=(mean${}_1$,...,mean${}_n$), std=(std${}_1$, ..., std${}_n$)

[自分へのメモ] MNISTデータのダウンロードに問題が生じているようだ。 2章2.1の記述のように、ブラウザでhttp://www.di.ens.fr/~lelarge/MNIST.tar.gz をダウンロードして、展開して./mnist_root/MNIST/のように配置してから、以下の処理を行うべし。

ネットワークの定義

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)

What are deconvolutional layers?

torch.nn.BatchNorm2d(num_features)
$\displaystyle y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \times \gamma + \beta$
$E[x]$: 平均、$Var[x]$: 分散, $\epsilon$: 0 divide を避けるための小さい数、 $\gamma, \beta$: learnable parameters

torch.nn.Module.apply(fn): サブモジュール.children()の各要素にfn(Modle → None)を適用する。

weight.data.normal_(mean, std): mean=平均、std=標準偏差

torchsummary

ネットワークのアーキテクチャを確認するライブラリ

torchsummary.summary(model, input_size=(channels, H, W))