3.2.5 オートエンコーダの解析

オートエンコーダの潜在空間内で画像がどのように表現されているかを調べる。

03_02_autoencoder_analysis.ipynb を参考にしながら、進める。

モデルのエンコーダに新しい画像を入力として与えて、その結果の2次元画像を散布図に描画してみる。 図3-8 を見ると、似たように見える数字が潜在空間の同じ部分に集まっているように見える。 訓練中に数字のラベルを全く与えていないのに、このような結果を得るということは、非常に興味深い。

  1. 図3-8 は、点 $(0,0)$ について対照ではなく、有界でもない。
  2. とても狭い領域で表現されている数字もあれば、広い範囲で表現されている数字もある。
  3. 色と色の間に、点をほとんど含まない大きな隙間がある。

ライブラリ

In [1]:
# DGL_code/utils/loaders.py
# gdl_ch03_01 と同じ
from tensorflow.keras.datasets import mnist

def load_mnist():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.astype('float32') / 255.0
    x_train = x_train.reshape(x_train.shape + (1,))
    x_test = x_test.astype('float32') / 255.0
    x_test = x_test.reshape(x_test.shape + (1,))
    return (x_train, y_train), (x_test, y_test)# DGL_code/utils/loaders.py
In [2]:
# DGL_code/utils/loaders.py
import os
import pickle

def load_model(model_class, folder):
    with open(os.path.join(folder, 'params.pkl'), 'rb') as f:
        params = pickle.load(f)
    model = model_class(*params)
    model.load_weights(os.path.join(folder, 'weights/weights.h5'))
    return model 
In [3]:
# DGL_code/models/AE.py
# [自分へのメモ] Autoencoder クラスの中の plot_mode() 関数で、フォルダを作成するコードを追加した。
# gdl_ch03_01 と同じ

from tensorflow.keras.layers import Input, Conv2D, LeakyReLU, BatchNormalization, Dropout, Flatten, Dense, Reshape, Conv2DTranspose, Activation
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import ModelCheckpoint

import os
import pickle


class Autoencoder():
    def __init__(self, 
                input_dim,
                encoder_conv_filters,
                encoder_conv_kernel_size,
                encoder_conv_strides,
                decoder_conv_t_filters,
                decoder_conv_t_kernel_size,
                decoder_conv_t_strides,
                z_dim,
                use_batch_norm = False,
                use_dropout = False
                ):
            self.name = 'autoencoder'
            self.input_dim = input_dim
            self.encoder_conv_filters = encoder_conv_filters
            self.encoder_conv_kernel_size = encoder_conv_kernel_size
            self.encoder_conv_strides = encoder_conv_strides
            self.decoder_conv_t_filters = decoder_conv_t_filters
            self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
            self.decoder_conv_t_strides = decoder_conv_t_strides
            self.z_dim = z_dim
            
            self.use_batch_norm = use_batch_norm
            self.use_dropout = use_dropout
            
            self.n_layers_encoder = len(encoder_conv_filters)
            self.n_layers_decoder = len(decoder_conv_t_filters)
            
            self._build()
 

    def _build(self):
        ### THE ENCODER
        encoder_input = Input(shape=self.input_dim, name='encoder_input')
        x = encoder_input
        
        for i in range(self.n_layers_encoder):
            conv_layer =Conv2D(
                filters = self.encoder_conv_filters[i],
                kernel_size = self.encoder_conv_kernel_size[i],
                strides = self.encoder_conv_strides[i],
                padding  = 'same',
                name = 'encoder_conv_' + str(i)
            )
            x = conv_layer(x)
            x = LeakyReLU()(x)
            if self.use_batch_norm:
                x = BatchNormalization()(x)
            if self.use_dropout:
                x = Dropout(rate = 0.25)(x)
        
        shape_before_flattening = K.int_shape(x)[1:]
        
        x = Flatten()(x)
        encoder_output = Dense(self.z_dim, name='encoder_output')(x)
        
        self.encoder = Model(encoder_input, encoder_output)
        
        ### THE DECODER
        decoder_input = Input(shape=(self.z_dim,), name='decoder_input')
        x = Dense(np.prod(shape_before_flattening))(decoder_input)
        x = Reshape(shape_before_flattening)(x)
        
        for i in range(self.n_layers_decoder):
            conv_t_layer =   Conv2DTranspose(
                filters = self.decoder_conv_t_filters[i],
                kernel_size = self.decoder_conv_t_kernel_size[i],
                strides = self.decoder_conv_t_strides[i],
                padding = 'same',
                name = 'decoder_conv_t_' + str(i)
            )
            x = conv_t_layer(x)
            
            if i < self.n_layers_decoder - 1:
                x = LeakyReLU()(x)
                if self.use_batch_norm:
                    x = BatchNormalization()(x)
                if self.use_dropout:
                    x = Dropout(rate=0.25)(x)
            else:
                x = Activation('sigmoid')(x)
       
        decoder_output = x
        self.decoder = Model(decoder_input, decoder_output)
        
        ### THE FULL AUTOENCODER
        model_input = encoder_input
        model_output = self.decoder(encoder_output)
        
        self.model = Model(model_input, model_output)

        
    def compile(self, learning_rate):
        self.learning_rate = learning_rate
        optimizer = Adam(lr=learning_rate)
        def r_loss(y_true, y_pred):
            return K.mean(K.square(y_true - y_pred), axis = [1,2,3])
        self.model.compile(optimizer=optimizer, loss = r_loss)
        
        
    def save(self, folder):
        if not os.path.exists(folder):
            os.makedirs(folder)
            os.makedirs(os.path.join(folder, 'viz'))
            os.makedirs(os.path.join(folder, 'weights'))
            os.makedirs(os.path.join(folder, 'images'))
            
        with open(os.path.join(folder, 'params.pkl'), 'wb') as f:
            pickle.dump([
                self.input_dim,
                self.encoder_conv_filters,
                self.encoder_conv_kernel_size,
                self.encoder_conv_strides,
                self.decoder_conv_t_filters,
                self.decoder_conv_t_kernel_size,
                self.decoder_conv_t_strides,
                self.z_dim,
                self.use_batch_norm,
                self.use_dropout
            ], f)
            
        self.plot_model(folder)
        
        
    def plot_model(self, run_folder):
        ### start of section added by nitta
        path = os.path.join(run_folder, 'viz')
        if not os.path.exists(path):
            os.makedirs(path)
        ### end of section added by nitta
        plot_model(self.model, to_file=os.path.join(run_folder, 'viz/model.png'), show_shapes=True, show_layer_names=True)
        plot_model(self.encoder, to_file=os.path.join(run_folder, 'viz/encoder.png'), show_shapes=True, show_layer_names=True)
        plot_model(self.decoder, to_file=os.path.join(run_folder, 'viz/decoder.png'), show_shapes=True, show_layer_names=True)

        
    def load_weights(self, filepath):
        self.model.load_weights(filepath)
        
        
    def train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches=100, initial_epoch=0, lr_decay=1):
        custom_callback = CustomCallback(run_folder, print_every_n_batches, initial_epoch, self)
        lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
        checkpoint2 = ModelCheckpoint(os.path.join(run_folder, 'weights/weights.h5'), save_weights_only=True, verbose=1)
        callbacks_list = [checkpoint2, custom_callback, lr_sched]
        self.model.fit(
            x_train,
            x_train,
            batch_size = batch_size,
            shuffle = True,
            epochs = epochs,
            initial_epoch = initial_epoch,
            callbacks = callbacks_list)
In [4]:
# GDL_code/utils/callbacks.py
# gdl_ch03_01 と同じ

import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import Callback, LearningRateScheduler

class CustomCallback(Callback):
    def __init__(self, run_folder, print_every_n_batches, initial_epoch, vae):
        self.run_folder = run_folder
        self.print_every_n_batches = print_every_n_batches
        self.epoch = initial_epoch
        self.vae = vae
        
        
    def on_train_batch_end(self, batch, logs={}):
        if batch % self.print_every_n_batches == 0:
            z_new = np.random.normal(size=(1,self.vae.z_dim))
            reconst = self.vae.decoder.predict(np.array(z_new))[0].squeeze()
            
            filepath = os.path.join(self.run_folder, 'images', 'img_'+str(self.epoch).zfill(3)+'_'+str(batch)+'.jpg')
            if len(reconst.shape) == 2:
                plt.imsave(filepath, reconst, cmap='gray_r')
            else:
                plt.imsave(filepath, reconst)
        
        
    def on_epoch_begin(self, epoch, logs={}):
        self.epoch += 1
        
        
def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1):
    '''
    Wrapper function to create a LearningRateScheduler with step decay schedule.
    '''
    def schedule(epoch):
        new_lr = initial_lr * (decay_factor ** np.floor(epoch/step_size))
        return new_lr
    return LearningRateScheduler(schedule)
In [5]:
# run params
# [自分へのメモ] os.mkdir() 関数はpathの途中のフォルダが存在しないとエラーとなるので os.makedirs() 関数に変更した。

SECTION = 'vae'
RUN_ID = '0001'
DATA_NAME = 'digits'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

ニューラルネットワークのモデルを重みつきでロードする。

In [6]:
AE = load_model(Autoencoder, RUN_FOLDER)

データをロードする

In [7]:
(x_train, y_train), (x_test, y_test) = load_mnist()

画像をコード化してからデコードしてみる

In [8]:
n_to_show = 10
example_idx = np.random.choice(range(len(x_test)), n_to_show)
example_images = x_test[example_idx]

z_points = AE.encoder.predict(example_images)
reconst_images = AE.decoder.predict(z_points)
In [9]:
%matplotlib inline
import matplotlib.pyplot as plt

VSKIP=1
fig, ax = plt.subplots(2, n_to_show, figsize=(2.8 * n_to_show, 2.8 * (2+VSKIP)))   # image.shape: 28x28x1
plt.subplots_adjust(hspace=VSKIP)

for i in range(n_to_show):
    img = example_images[i].squeeze()   # original image
    ax[0][i].imshow(img,cmap='gray_r') 
    ax[0][i].axis('off')
    
    ax[0][i].text(0.5, -0.35, str(np.round(z_points[i], 1)), fontsize=16, ha='center', transform=ax[0][i].transAxes)
    
    img2 = reconst_images[i].squeeze()    # reconstructed image
    ax[1][i].imshow(img2, cmap='gray_r')
    ax[1][i].axis('off')

Mr N. Coder's wall

画像を、潜在空間の2次元座標にエンコードして、点として描画する。 描画する画像はtest用画像から 5000 枚の画像をランダムに選ぶ。

In [10]:
n_to_show = 5000

example_idx = np.random.choice(range(len(x_test)), n_to_show)
example_images = x_test[example_idx]
example_labels = y_test[example_idx]

z_points = AE.encoder.predict(example_images)

min_x = min(z_points[:, 0])
max_x = max(z_points[:, 0])
min_y = min(z_points[:, 1])
max_y = max(z_points[:,1])
In [11]:
%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.scatter(z_points[:, 0], z_points[:, 1], c='black', alpha=0.5, s=2)

plt.show()

正解ラベルに応じて点の色を変化させると次のようになる。

In [12]:
%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
map = ax.scatter(z_points[:, 0], z_points[:, 1], c=example_labels, cmap='rainbow', alpha=0.5, s=2)

plt.colorbar(map)   # plt.colorbar() だとエラーになるので注意。pltに対して描画していない場合は、colorbar()の引数にMappableを指定する必要がある。
plt.show()

The new generated art exhibition

上でランダムに選択したテスト用画像がエンコードされた座標の範囲内で、ランダムに点を $10 \times 3 = 30$ 個生成して、画像にデコードする。 ランダムに生成した点は赤い点として表示する。

In [13]:
import numpy as np

table_row = 10    # 表の横方向サイズ
table_line = 3   #表の縦方向サイズ

x = np.random.uniform(min_x, max_x, size=table_line * table_row)
y = np.random.uniform(min_y, max_y, size=table_line * table_row)
z_grid = np.array(list(zip(x,y)))    # (x, y) : 2D coordinates
reconst = AE.decoder.predict(z_grid)
In [14]:
%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1,1,figsize=(8, 8))
ax.scatter(z_points[:, 0], z_points[:, 1], c='black', alpha=0.5, s=2)

ax.scatter(z_grid[:, 0], z_grid[:, 1], c='red', alpha=1, s=20)

plt.show()

生成し点の座標と、デコードして得られた画像を table_line (=3) 行 table_row (=10) 列の表として図示する。

In [15]:
%matplotlib inline
import matplotlib.pyplot as plt

VSKIP=0.5   # vertical space between subplots

fig, ax = plt.subplots(table_line, table_row, figsize=(2.8 * table_row, 2.8 * table_line * (1+VSKIP)))
plt.subplots_adjust(hspace = VSKIP)
                       
for y in range(table_line):
    for x in range(table_row):
        idx = table_row * y + x
        img = reconst[idx].squeeze()
        ax[y][x].imshow(img, cmap='gray')
        ax[y][x].text(0.5, -0.35, str(np.round(z_grid[idx], 1)), fontsize=16, ha='center', transform=ax[y][x].transAxes)
        ax[y][x].axis('off')
        
plt.show()

潜在空間をグリッドに区切って、各座標からどのような画像が生成(デコード)されるか調べる

$20 \times 20 $ のグリッドから、画像を生成する。 生成した画像は 20 行 20列の表として表示する。

In [16]:
import numpy as np

n_grid = 20

x = np.linspace(min_x, max_x, n_grid)
y = np.linspace(min_y, max_y, n_grid)

xv, yv = np.meshgrid(x, y)
xv = xv.flatten()
yv = yv.flatten()
z_grid2 = np.array(list(zip(xv, yv)))

reconst2 = AE.decoder.predict(z_grid2)
In [17]:
%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(n_grid, n_grid, figsize=(n_grid, n_grid))
for i in range(len(reconst2)):
    img = reconst2[i].squeeze()
    line = i // n_grid
    row = i % n_grid
    ax[line][row].imshow(img, cmap='gray')
    ax[line][row].axis('off')
    
plt.show()