4章 敵対的生成ネットワーク

4.6 WGAN-GP

WGAN の最近の拡張は WGAN-GP (Wasserstein GAN - Gradient Penalty) フレームワークである。

WGAN-GP の Generator は WGAN の Generator と同じやり方で定義、コンパイルされる。 Discriminator の定義とコンパイルが異なる。

  • Discriminator の損失関数に Gradient Penalty (勾配ペナルティ) 項を含める。
  • Discriminator の重みをクリッピングしない。
  • Discriminator 内ではバッチ正規化をしない。

Lipschitz 制約を課す別の方法を提案している。 その方法では、Discriminator の重みをクリッピングするのではなく、 Discriminator の勾配ノルムが1から離れた場合にモデルにペナルティを与える項を損失関数に導入している。

4.6.1 Gradient Penelty Loss (勾配ペナルティ損失)

原画像(本物)と生成画像(偽物)からの Wasserstein 損失と並んで、損失関数に Gradient Penalty Loss を追加した。

Gradient Penalty Loss は、入力画像に関する predict (予測) の gradient (勾配) の norm と 1 との間の差の平方根を計測する。 このモデルは、自然に Gradient Penalty 項が最小化されるような重みを見つけるようになり、 結果として、Lipschitz 制約に従うことになる。

この gradient (勾配) を training 中にいつも計算するのは大変なので、 WGAN-GP はほんの少しの箇所だけで gradient を評価する。 ミキシングのバランスを取るために、本物の画像のバッチと偽の画像のバッチを組にしてつないだ線に沿って ランダムに選ばれた点にある補間画像のセットを用いる。

実装

Large-scale CelebFaces Attributes (CelebA) Dataset https://bit.ly/2WSiOXt から セレブの顔データ 'Align&Cropped Images' https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg を GAN_code/data/celeb/ に用意する必要がある。

ダウンロードについて

現時点(2021/09)では、 Large-scale CelebFaces Attributes (CelebA) Dataset https://bit.ly/2WSiOXtURL にアクセスすると Downloads には以下のデータが示される。

  • ZIP: In-The-Wild Images
  • ZIP: Align&Cropped Images
  • TXT: Landmarks Annotations
  • TXT: Attributes Annotations
  • TXT: Identities Annoations
  • TXT: Train/Val/Test Partitions

このうちの、 Align&Cropped Images にアクセスをするとGoogle Driveの https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg にアクセスすると以下のファイル群がある。

dataset/CelebA
  +--README.txt
  +--Img
  |    +--img_celeba.7z
  |    +--img_align_celeba_png.7z
  |    +--img_align_celeba.zip
  +--Eval
  |    +--list_eval_partition.txt
  +--Anno
       +--list_landmarks_celeba.txt
       +--list_landmarks_align_celeba.txt
       +--list_bbox_celeba.txt
       +--list_attr_celeba.txt
       +--identity_CelebA.txt

このうちの Img/img_align_celeba.zip が jpg ファイルの集まりなので、ダウンロードして 手元のPCの以下のパスに展開した。

  • gtune ... d:/tmp/celb/*.jpg
  • galleria ... d:/tmp/CelebA/img_align_celeba/*.jpg
In [1]:
IMAGE_SIZE = 64
BATCH_SIZE = 64
In [2]:
# GDL_code/utils/loaders.py
# [自習] D:\tmp\CelebA\img_align_celeba\*.jpg から画像ファイルを読みだすように変更した。
# Windowsの場合でもString中の path は '/' または '\\' で区切ることに注意する。
# flow_from_directory で指定するdirectoryは画像ファイルが置いてあるフォルダの一つ上までのパスを指定すること。

import tensorflow as tf
import numpy as np

DATA_PATH = 'D:\\tmp\\CelebA'

data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=lambda x: (x.astype('float32') - 127.5) / 127.5
)

x_train = data_gen.flow_from_directory(
    directory=DATA_PATH,
    target_size = (IMAGE_SIZE, IMAGE_SIZE),
    batch_size = BATCH_SIZE,
    shuffle = True,
    class_mode = 'input',
    subset='training'
)
Found 202599 images belonging to 1 classes.
In [3]:
# [自習] x_train はnext()を作用させると BATCH_SIZE だけ取り出される。
# class_mode = 'input' を指定しているので、ラベルとしてデータそのものが使われる。
# x = next(x_train) として取り出したxの0次元目のサイズは2である。
# x[0][i] と x[1][j] は同じ画像である。

x = np.array(next(x_train))
print(x.shape)
(2, 64, 64, 64, 3)
In [4]:
# [自習]
print(x[0])
[[[[ 0.62352943  0.67058825  0.56078434]
   [ 0.62352943  0.67058825  0.56078434]
   [ 0.6313726   0.6784314   0.5686275 ]
   ...
   [ 0.6392157   0.654902    0.5686275 ]
   [ 0.7019608   0.70980394  0.64705884]
   [ 0.7647059   0.77254903  0.7254902 ]]

  [[ 0.654902    0.7019608   0.60784316]
   [ 0.64705884  0.69411767  0.6       ]
   [ 0.64705884  0.69411767  0.6       ]
   ...
   [ 0.67058825  0.6862745   0.6       ]
   [ 0.7411765   0.7490196   0.6862745 ]
   [ 0.78039217  0.7882353   0.7411765 ]]

  [[ 0.64705884  0.69411767  0.6156863 ]
   [ 0.64705884  0.69411767  0.6156863 ]
   [ 0.64705884  0.69411767  0.6156863 ]
   ...
   [ 0.7176471   0.73333335  0.64705884]
   [ 0.75686276  0.7647059   0.7019608 ]
   [ 0.7647059   0.77254903  0.7254902 ]]

  ...

  [[ 0.56078434  0.23921569 -0.2       ]
   [ 0.6784314   0.27058825 -0.09803922]
   [ 0.6627451   0.24705882 -0.08235294]
   ...
   [ 0.21568628  0.05098039 -0.09803922]
   [ 0.19215687  0.05882353 -0.06666667]
   [ 0.23921569  0.13725491  0.00392157]]

  [[ 0.48235294  0.12941177 -0.3019608 ]
   [ 0.5921569   0.16862746 -0.19215687]
   [ 0.6313726   0.21568628 -0.11372549]
   ...
   [ 0.33333334  0.14509805 -0.05882353]
   [ 0.3254902   0.12941177 -0.04313726]
   [ 0.24705882  0.09019608 -0.09803922]]

  [[ 0.46666667  0.09803922 -0.30980393]
   [ 0.5137255   0.06666667 -0.27058825]
   [ 0.5137255   0.09803922 -0.21568628]
   ...
   [ 0.27058825  0.03529412 -0.23137255]
   [ 0.30980393  0.04313726 -0.1764706 ]
   [ 0.30980393  0.08235294 -0.15294118]]]


 [[[ 0.79607844  0.79607844  0.8745098 ]
   [ 0.7882353   0.7882353   0.88235295]
   [ 0.8039216   0.8039216   0.8980392 ]
   ...
   [ 0.79607844  0.8039216   0.84313726]
   [ 0.85882354  0.8666667   0.90588236]
   [ 0.85882354  0.85882354  0.9372549 ]]

  [[ 0.79607844  0.79607844  0.8745098 ]
   [ 0.7882353   0.7882353   0.88235295]
   [ 0.8039216   0.8039216   0.8980392 ]
   ...
   [ 0.81960785  0.827451    0.8666667 ]
   [ 0.85882354  0.8666667   0.90588236]
   [ 0.84313726  0.84313726  0.92156863]]

  [[ 0.79607844  0.79607844  0.8745098 ]
   [ 0.8039216   0.8039216   0.8980392 ]
   [ 0.8117647   0.8117647   0.90588236]
   ...
   [ 0.827451    0.8352941   0.8745098 ]
   [ 0.8666667   0.8745098   0.9137255 ]
   [ 0.8352941   0.8352941   0.9137255 ]]

  ...

  [[ 0.15294118 -0.12941177 -0.39607844]
   [-0.14509805 -0.46666667 -0.7176471 ]
   [ 0.05098039 -0.29411766 -0.52156866]
   ...
   [-0.5294118  -0.92156863 -0.92941177]
   [-0.5529412  -0.94509804 -0.85882354]
   [ 0.4117647   0.23137255  0.43529412]]

  [[-0.03529412 -0.31764707 -0.58431375]
   [-0.09803922 -0.41960785 -0.654902  ]
   [-0.2627451  -0.60784316 -0.81960785]
   ...
   [ 0.18431373 -0.21568628 -0.44313726]
   [ 0.1764706  -0.22352941 -0.3882353 ]
   [-0.34117648 -0.60784316 -0.7019608 ]]

  [[-0.16078432 -0.4509804  -0.6784314 ]
   [-0.21568628 -0.5372549  -0.75686276]
   [-0.09803922 -0.44313726 -0.6392157 ]
   ...
   [-0.13725491 -0.5764706  -0.8509804 ]
   [-0.13725491 -0.56078434 -0.84313726]
   [-0.03529412 -0.45882353 -0.7411765 ]]]


 [[[-0.39607844 -0.40392157 -0.56078434]
   [-0.38039216 -0.3882353  -0.54509807]
   [-0.35686275 -0.38039216 -0.5294118 ]
   ...
   [-0.60784316 -0.7254902  -0.8745098 ]
   [ 0.05098039 -0.04313726 -0.23137255]
   [-0.06666667 -0.21568628 -0.46666667]]

  [[-0.39607844 -0.40392157 -0.56078434]
   [-0.38039216 -0.3882353  -0.54509807]
   [-0.35686275 -0.38039216 -0.5294118 ]
   ...
   [-0.19215687 -0.30980393 -0.45882353]
   [ 0.03529412 -0.07450981 -0.2784314 ]
   [-0.06666667 -0.22352941 -0.48235294]]

  [[-0.38039216 -0.3882353  -0.54509807]
   [-0.3647059  -0.37254903 -0.5294118 ]
   [-0.35686275 -0.38039216 -0.5294118 ]
   ...
   [ 0.08235294 -0.05098039 -0.19215687]
   [ 0.12156863 -0.01176471 -0.23137255]
   [ 0.01176471 -0.14509805 -0.40392157]]

  ...

  [[ 0.7882353   0.6784314   0.60784316]
   [ 0.7882353   0.6784314   0.60784316]
   [ 0.77254903  0.6627451   0.5921569 ]
   ...
   [ 0.35686275  0.06666667 -0.08235294]
   [ 0.654902    0.37254903  0.2       ]
   [ 0.6627451   0.38039216  0.1764706 ]]

  [[ 0.7882353   0.6784314   0.60784316]
   [ 0.7882353   0.6784314   0.60784316]
   [ 0.79607844  0.6862745   0.6156863 ]
   ...
   [ 0.42745098  0.13725491 -0.01176471]
   [ 0.42745098  0.14509805 -0.02745098]
   [ 0.7176471   0.43529412  0.23137255]]

  [[ 0.8117647   0.7019608   0.6313726 ]
   [ 0.8117647   0.7019608   0.6313726 ]
   [ 0.79607844  0.6862745   0.6156863 ]
   ...
   [ 0.8745098   0.58431375  0.43529412]
   [ 0.6627451   0.38039216  0.20784314]
   [ 0.60784316  0.3254902   0.12156863]]]


 ...


 [[[ 0.48235294  0.6156863   0.46666667]
   [ 0.28627452  0.5137255  -0.12941177]
   [ 0.3019608   0.5372549  -0.04313726]
   ...
   [ 0.37254903  0.56078434 -0.09803922]
   [ 0.43529412  0.6        -0.12941177]
   [ 0.45882353  0.5764706  -0.13725491]]

  [[ 0.49019608  0.62352943  0.4745098 ]
   [ 0.3019608   0.5294118  -0.11372549]
   [ 0.31764707  0.5529412  -0.02745098]
   ...
   [ 0.38039216  0.5686275  -0.09019608]
   [ 0.44313726  0.60784316 -0.12156863]
   [ 0.4745098   0.5921569  -0.13725491]]

  [[ 0.5058824   0.6392157   0.49019608]
   [ 0.31764707  0.54509807 -0.09803922]
   [ 0.31764707  0.5529412  -0.02745098]
   ...
   [ 0.3882353   0.5764706  -0.08235294]
   [ 0.45882353  0.62352943 -0.10588235]
   [ 0.45882353  0.5764706  -0.15294118]]

  ...

  [[-0.73333335 -0.88235295 -0.92941177]
   [-0.8509804  -0.92156863 -0.99215686]
   [-0.88235295 -0.9843137  -1.        ]
   ...
   [-0.827451   -0.9372549  -1.        ]
   [-0.7882353  -0.8980392  -0.96862745]
   [-0.7411765  -0.8745098  -0.9529412 ]]

  [[-0.5372549  -0.7647059  -0.8117647 ]
   [-0.7882353  -0.90588236 -0.9607843 ]
   [-0.85882354 -0.9764706  -1.        ]
   ...
   [-0.7882353  -0.8980392  -0.96862745]
   [-0.77254903 -0.88235295 -0.9529412 ]
   [-0.56078434 -0.69411767 -0.77254903]]

  [[-0.18431373 -0.5529412  -0.6313726 ]
   [-0.60784316 -0.79607844 -0.8745098 ]
   [-0.79607844 -0.92941177 -0.9843137 ]
   ...
   [-0.77254903 -0.88235295 -0.9529412 ]
   [-0.73333335 -0.84313726 -0.9137255 ]
   [-0.64705884 -0.78039217 -0.85882354]]]


 [[[-0.5372549  -0.7490196  -0.8039216 ]
   [-0.48235294 -0.7411765  -0.78039217]
   [-0.49803922 -0.77254903 -0.8039216 ]
   ...
   [-0.8509804  -0.8666667  -0.85882354]
   [-0.8666667  -0.88235295 -0.8745098 ]
   [-0.94509804 -0.9607843  -0.9529412 ]]

  [[-0.5372549  -0.7490196  -0.8039216 ]
   [-0.48235294 -0.7411765  -0.78039217]
   [-0.49803922 -0.77254903 -0.8039216 ]
   ...
   [-0.8509804  -0.8666667  -0.85882354]
   [-0.8666667  -0.88235295 -0.8745098 ]
   [-0.94509804 -0.9607843  -0.9529412 ]]

  [[-0.54509807 -0.75686276 -0.8117647 ]
   [-0.4745098  -0.73333335 -0.77254903]
   [-0.49019608 -0.7647059  -0.79607844]
   ...
   [-0.84313726 -0.85882354 -0.8509804 ]
   [-0.85882354 -0.8745098  -0.8666667 ]
   [-0.94509804 -0.9607843  -0.9529412 ]]

  ...

  [[ 0.05882353  0.14509805  0.16078432]
   [ 0.38039216  0.49019608  0.5137255 ]
   [ 0.34117648  0.48235294  0.49803922]
   ...
   [-0.6        -0.6        -0.6       ]
   [-0.827451   -0.827451   -0.827451  ]
   [-0.84313726 -0.84313726 -0.84313726]]

  [[ 0.4745098   0.62352943  0.654902  ]
   [ 0.2784314   0.42745098  0.45882353]
   [ 0.46666667  0.6313726   0.654902  ]
   ...
   [-0.5137255  -0.5137255  -0.5137255 ]
   [-0.85882354 -0.85882354 -0.85882354]
   [-0.8352941  -0.8352941  -0.8352941 ]]

  [[ 0.42745098  0.67058825  0.6862745 ]
   [ 0.5686275   0.7882353   0.8117647 ]
   [ 0.4117647   0.6         0.6156863 ]
   ...
   [-0.4117647  -0.4117647  -0.4117647 ]
   [-0.81960785 -0.81960785 -0.81960785]
   [-0.8352941  -0.8352941  -0.8352941 ]]]


 [[[ 0.7490196   0.9529412   0.88235295]
   [ 0.70980394  0.9137255   0.84313726]
   [ 0.7254902   0.92941177  0.85882354]
   ...
   [ 1.          1.          1.        ]
   [ 1.          1.          1.        ]
   [ 1.          1.          1.        ]]

  [[ 0.7411765   0.94509804  0.8745098 ]
   [ 0.7490196   0.9529412   0.88235295]
   [ 0.7490196   0.9529412   0.88235295]
   ...
   [ 1.          1.          1.        ]
   [ 1.          1.          1.        ]
   [ 1.          1.          1.        ]]

  [[ 0.77254903  0.9764706   0.90588236]
   [ 0.7411765   0.94509804  0.8745098 ]
   [ 0.7490196   0.9529412   0.88235295]
   ...
   [ 1.          1.          1.        ]
   [ 1.          1.          1.        ]
   [ 1.          1.          1.        ]]

  ...

  [[-0.8901961  -0.85882354 -0.8666667 ]
   [-0.8901961  -0.85882354 -0.8666667 ]
   [-0.90588236 -0.8745098  -0.88235295]
   ...
   [-0.8117647  -0.78039217 -0.77254903]
   [-0.8117647  -0.78039217 -0.77254903]
   [-0.827451   -0.79607844 -0.77254903]]

  [[-0.8745098  -0.84313726 -0.8509804 ]
   [-0.8745098  -0.84313726 -0.8509804 ]
   [-0.8901961  -0.85882354 -0.8666667 ]
   ...
   [-0.81960785 -0.7882353  -0.78039217]
   [-0.81960785 -0.7882353  -0.78039217]
   [-0.79607844 -0.7647059  -0.7411765 ]]

  [[-0.8901961  -0.85882354 -0.8666667 ]
   [-0.92156863 -0.8901961  -0.8980392 ]
   [-0.9372549  -0.90588236 -0.9137255 ]
   ...
   [-0.81960785 -0.7882353  -0.78039217]
   [-0.827451   -0.79607844 -0.7882353 ]
   [-0.81960785 -0.7882353  -0.7647059 ]]]]
In [5]:
%matplotlib inline
# [自習] データ、ラベル, ラベル2(省略可)を表示する関数
import matplotlib.pyplot as plt
import numpy as np

def showImages(xs, ts=[], ts2=[], rows=-1, cols=-1, w=3.2, h=3.2):
    N = len(xs)
    if rows < 0: rows = 1
    if cols < 0: cols = (N + rows - 1) // rows
    scale = 1
    if len(ts) > 0:
        scale += .25
    if len(ts2) > 0:
        scale += .25
    fig, ax = plt.subplots(rows, cols, figsize=(w*cols, w*rows*scale))
    idx = 0
    for row in range(rows):
        for col in range(cols):
            if rows == 1 and cols == 1:
                axis = ax
            elif cols == 1:
                axis = ax[col]
            elif rows == 1:
                axis = ax[row]
            else:
                axis = ax[row][col]
 
            if idx < N:
                axis.imshow(xs[idx])
                if idx < len(ts):
                    axis.text(0.5, -0.25, f'{ts[idx]}', fontsize=12, ha='center', transform=axis.transAxes)
                if idx < len(ts2):
                    axis.text(0.5, -0.25, f'{ts2[idx]}', fontsize=12, ha='center', transform=axis.transAxes)                    
            axis.axis('off')
            idx += 1
    plt.show()
In [6]:
# [自習] データをいくつかランダムに選んで表示してみる。
N_LINE = 5
N_COL = 5

showImages((x[0][:(N_LINE*N_COL)]+1)/2, [], [], N_LINE, N_COL)
In [7]:
# [自習] x_train[:] の0次元目はサイズ2である。x_train[0] と x_train[1]は同じものである。
# [自習] データをいくつかランダムに選んで表示してみる。
N_LINE = 5
N_COL = 5

showImages((x[1][:(N_LINE*N_COL)]+1)/2, [], [], N_LINE, N_COL)