Updated 21/Nov/2021 by Yoshihisa Nitta  

Variational Auto Encoder Training for CelebA dataset with Tensorflow 2 on Google Colab

Train a Variational Auto Encoder on the CelebA dataset. See VAE_MNIST_Train.ipynb for a description of the Variational Auto Encoder.

In [1]:
#! pip install tensorflow==2.7.0
In [2]:
%tensorflow_version 2.x

import tensorflow as tf
print(tf.__version__)
2.7.0

Check the Google Colab runtime environment

In [3]:
! nvidia-smi
! cat /proc/cpuinfo
! cat /etc/issue
! free -h
Sun Nov 21 13:55:57 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2199.998
cache size	: 56320 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4399.99
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

processor	: 1
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2199.998
cache size	: 56320 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 1
initial apicid	: 1
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4399.99
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

Ubuntu 18.04.5 LTS \n \l

              total        used        free      shared  buff/cache   available
Mem:            12G        867M        3.0G        1.2M        8.9G         11G
Swap:            0B          0B          0B

Mount Google Drive from Google Colab

In [4]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [5]:
! ls /content/drive
MyDrive  Shareddrives

Download source file from Google Drive or nw.tsuda.ac.jp

Normally, download it from Google Drive with gdown. Fall back to nw.tsuda.ac.jp only if the Google Drive specifications change and the file can no longer be downloaded from Google Drive.

In [7]:
# Download source file
nw_path = './nw'
! rm -rf {nw_path}
! mkdir -p {nw_path}

if True:   # from Google Drive
    url_model =  'https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3'
    ! (cd {nw_path}; gdown {url_model})
else:      # from nw.tsuda.ac.jp
    URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'
    url_model = f'{URL_NW}/models/VariationalAutoEncoder.py'
    ! wget -nd {url_model} -P {nw_path}
Downloading...
From: https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3
To: /content/nw/VariationalAutoEncoder.py
100% 18.7k/18.7k [00:00<00:00, 29.1MB/s]
In [8]:
! cat {nw_path}/VariationalAutoEncoder.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import os
import pickle
import datetime

class Sampling(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        mu, log_var = inputs
        epsilon = tf.keras.backend.random_normal(shape=tf.keras.backend.shape(mu), mean=0., stddev=1.)
        return mu + tf.keras.backend.exp(log_var / 2) * epsilon


class VAEModel(tf.keras.models.Model):
    def __init__(self, encoder, decoder, r_loss_factor, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.r_loss_factor = r_loss_factor


    @tf.function
    def loss_fn(self, x):
        z_mean, z_log_var, z = self.encoder(x)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            tf.square(x - reconstruction), axis=[1,2,3]
        ) * self.r_loss_factor
        kl_loss = tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
            axis = 1
        ) * (-0.5)
        total_loss = reconstruction_loss + kl_loss
        return total_loss, reconstruction_loss, kl_loss


    @tf.function
    def compute_loss_and_grads(self, x):
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self.loss_fn(x)
        grads = tape.gradient(total_loss, self.trainable_weights)
        return total_loss, reconstruction_loss, kl_loss, grads


    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        total_loss, reconstruction_loss, kl_loss, grads = self.compute_loss_and_grads(data)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": tf.math.reduce_mean(total_loss),
            "reconstruction_loss": tf.math.reduce_mean(reconstruction_loss),
            "kl_loss": tf.math.reduce_mean(kl_loss),
        }

    def call(self,inputs):
        _, _, z = self.encoder(inputs)
        return self.decoder(z)


class VariationalAutoEncoder():
    def __init__(self, 
                 input_dim,
                 encoder_conv_filters,
                 encoder_conv_kernel_size,
                 encoder_conv_strides,
                 decoder_conv_t_filters,
                 decoder_conv_t_kernel_size,
                 decoder_conv_t_strides,
                 z_dim,
                 r_loss_factor,   ### added
                 use_batch_norm = False,
                 use_dropout = False,
                 epoch = 0
                ):
        self.name = 'variational_autoencoder'
        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim
        self.r_loss_factor = r_loss_factor   ### added
            
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout

        self.epoch = epoch
            
        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)
            
        self._build()
 

    def _build(self):
        ### THE ENCODER
        encoder_input = tf.keras.layers.Input(shape=self.input_dim, name='encoder_input')
        x = encoder_input
        
        for i in range(self.n_layers_encoder):
            x = tf.keras.layers.Conv2D(
                filters = self.encoder_conv_filters[i],
                kernel_size = self.encoder_conv_kernel_size[i],
                strides = self.encoder_conv_strides[i],
                padding  = 'same',
                name = 'encoder_conv_' + str(i)
            )(x)

            if self.use_batch_norm:                                ### The order of layers is opposite to AutoEncoder
                x = tf.keras.layers.BatchNormalization()(x)        ###   AE: LeakyReLU -> BatchNorm
            x = tf.keras.layers.LeakyReLU()(x)                     ###   VAE: BatchNorm -> LeakyReLU
            
            if self.use_dropout:
                x = tf.keras.layers.Dropout(rate = 0.25)(x)
        
        shape_before_flattening = tf.keras.backend.int_shape(x)[1:]
        
        x = tf.keras.layers.Flatten()(x)
        
        self.mu = tf.keras.layers.Dense(self.z_dim, name='mu')(x)
        self.log_var = tf.keras.layers.Dense(self.z_dim, name='log_var')(x) 
        self.z = Sampling(name='encoder_output')([self.mu, self.log_var])
        
        self.encoder = tf.keras.models.Model(encoder_input, [self.mu, self.log_var, self.z], name='encoder')
        
        
        ### THE DECODER
        decoder_input = tf.keras.layers.Input(shape=(self.z_dim,), name='decoder_input')
        x = decoder_input
        x = tf.keras.layers.Dense(np.prod(shape_before_flattening))(x)
        x = tf.keras.layers.Reshape(shape_before_flattening)(x)
        
        for i in range(self.n_layers_decoder):
            x = tf.keras.layers.Conv2DTranspose(
                filters = self.decoder_conv_t_filters[i],
                kernel_size = self.decoder_conv_t_kernel_size[i],
                strides = self.decoder_conv_t_strides[i],
                padding = 'same',
                name = 'decoder_conv_t_' + str(i)
            )(x)
            
            if i < self.n_layers_decoder - 1:
                if self.use_batch_norm:                           ### The order of layers is opposite to AutoEncoder
                    x = tf.keras.layers.BatchNormalization()(x)   ###     AE: LeakyReLU -> BatchNorm
                x = tf.keras.layers.LeakyReLU()(x)                ###      VAE: BatchNorm -> LeakyReLU                
                if self.use_dropout:
                    x = tf.keras.layers.Dropout(rate=0.25)(x)
            else:
                x = tf.keras.layers.Activation('sigmoid')(x)
       
        decoder_output = x
        self.decoder = tf.keras.models.Model(decoder_input, decoder_output, name='decoder')  ### added (name)
        
        ### THE FULL AUTOENCODER
        self.model = VAEModel(self.encoder, self.decoder, self.r_loss_factor)
        
        
    def save(self, folder):
        self.save_params(os.path.join(folder, 'params.pkl'))
        self.save_weights(folder)


    @staticmethod
    def load(folder, epoch=None):  # VariationalAutoEncoder.load(folder)
        params = VariationalAutoEncoder.load_params(os.path.join(folder, 'params.pkl'))
        VAE = VariationalAutoEncoder(*params)
        if epoch is None:
            VAE.load_weights(folder)
        else:
            VAE.load_weights(folder, epoch-1)
            VAE.epoch = epoch
        return VAE

        
    def save_params(self, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        with open(filepath, 'wb') as f:
            pickle.dump([
                self.input_dim,
                self.encoder_conv_filters,
                self.encoder_conv_kernel_size,
                self.encoder_conv_strides,
                self.decoder_conv_t_filters,
                self.decoder_conv_t_kernel_size,
                self.decoder_conv_t_strides,
                self.z_dim,
                self.r_loss_factor,
                self.use_batch_norm,
                self.use_dropout,
                self.epoch
            ], f)


    @staticmethod
    def load_params(filepath):
        with open(filepath, 'rb') as f:
            params = pickle.load(f)
        return params


    def save_weights(self, folder, epoch=None):
        if epoch is None:
            self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights.h5'))
            self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights.h5'))
        else:
            self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))
            self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))


    def save_model_weights(self, model, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        model.save_weights(filepath)


    def load_weights(self, folder, epoch=None):
        if epoch is None:
            self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights.h5'))
            self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights.h5'))
        else:
            self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))
            self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))


    def save_images(self, imgs, filepath):
        z_mean, z_log_var, z = self.encoder.predict(imgs)
        reconst_imgs = self.decoder.predict(z)
        txts = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z ]
        VariationalAutoEncoder.showImages(imgs, reconst_imgs, txts, 1.4, 1.4, 0.5, filepath)
      

    def compile(self, learning_rate):
        self.learning_rate = learning_rate
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.model.compile(optimizer=optimizer)     # CAUTION!!!: loss(y_true, y_pred) function is not specified.
        
        
    def train_with_fit(
            self,
            x_train,
            batch_size,
            epochs,
            run_folder='run/'
    ):
        history = self.model.fit(
            x_train,
            x_train,
            batch_size = batch_size,
            shuffle=True,
            initial_epoch = self.epoch,
            epochs = epochs
        )
        if (self.epoch < epochs):
            self.epoch = epochs

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)
        
        return history


    def train_generator_with_fit(
            self,
            data_flow,
            epochs,
            run_folder='run/'
    ):
        history = self.model.fit(
            data_flow,
            initial_epoch = self.epoch,
            epochs = epochs
        )
        if (self.epoch < epochs):
            self.epoch = epochs

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)
        
        return history


    def train_tf(
            self,
            x_train,
            batch_size = 32,
            epochs = 10,
            shuffle = False,
            run_folder = 'run/',
            optimizer = None,
            save_epoch_interval = 100,
            validation_data = None
    ):
        start_time = datetime.datetime.now()
        steps = x_train.shape[0] // batch_size

        total_losses = []
        reconstruction_losses = []
        kl_losses = []

        val_total_losses = []
        val_reconstruction_losses = []
        val_kl_losses = []

        for epoch in range(self.epoch, epochs):
            epoch_loss = 0
            indices = tf.range(x_train.shape[0], dtype=tf.int32)
            if shuffle:
                indices = tf.random.shuffle(indices)
            x_ = x_train[indices]

            step_total_losses = []
            step_reconstruction_losses = []
            step_kl_losses = []
            for step in range(steps):
                start = batch_size * step
                end = start + batch_size

                total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x_[start:end])
                optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
                
                step_total_losses.append(np.mean(total_loss))
                step_reconstruction_losses.append(np.mean(reconstruction_loss))
                step_kl_losses.append(np.mean(kl_loss))
            
            epoch_total_loss = np.mean(step_total_losses)
            epoch_reconstruction_loss = np.mean(step_reconstruction_losses)
            epoch_kl_loss = np.mean(step_kl_losses)

            total_losses.append(epoch_total_loss)
            reconstruction_losses.append(epoch_reconstruction_loss)
            kl_losses.append(epoch_kl_loss)

            val_str = ''
            if not validation_data is None:
                x_val = validation_data
                tl, rl, kl = self.model.loss_fn(x_val)
                val_tl = np.mean(tl)
                val_rl = np.mean(rl)
                val_kl = np.mean(kl)
                val_total_losses.append(val_tl)
                val_reconstruction_losses.append(val_rl)
                val_kl_losses.append(val_kl)
                val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '

            if (epoch+1) % save_epoch_interval == 0 and run_folder != None:
                self.save(run_folder)
                self.save_weights(run_folder, self.epoch)

            elapsed_time = datetime.datetime.now() - start_time
            print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')

            self.epoch += 1

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)

        dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }
        if not validation_data is None:
            dic['val_loss'] = val_total_losses
            dic['val_reconstruction_loss'] = val_reconstruction_losses
            dic['val_kl_loss'] = val_kl_losses

        return dic
            

    def train_tf_generator(
            self,
            data_flow,
            epochs = 10,
            run_folder = 'run/',
            optimizer = None,
            save_epoch_interval = 100,
            validation_data_flow = None
    ):
        start_time = datetime.datetime.now()
        steps = len(data_flow)

        total_losses = []
        reconstruction_losses = []
        kl_losses = []

        val_total_losses = []
        val_reconstruction_losses = []
        val_kl_losses = []

        for epoch in range(self.epoch, epochs):
            epoch_loss = 0

            step_total_losses = []
            step_reconstruction_losses = []
            step_kl_losses = []

            for step in range(steps):
                x, _ = next(data_flow)

                total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x)
                optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
                
                step_total_losses.append(np.mean(total_loss))
                step_reconstruction_losses.append(np.mean(reconstruction_loss))
                step_kl_losses.append(np.mean(kl_loss))
            
            epoch_total_loss = np.mean(step_total_losses)
            epoch_reconstruction_loss = np.mean(step_reconstruction_losses)
            epoch_kl_loss = np.mean(step_kl_losses)

            total_losses.append(epoch_total_loss)
            reconstruction_losses.append(epoch_reconstruction_loss)
            kl_losses.append(epoch_kl_loss)

            val_str = ''
            if not validation_data_flow is None:
                step_val_tl = []
                step_val_rl = []
                step_val_kl = []
                for i in range(len(validation_data_flow)):
                    x, _ = next(validation_data_flow)
                    tl, rl, kl = self.model.loss_fn(x)
                    step_val_tl.append(np.mean(tl))
                    step_val_rl.append(np.mean(rl))
                    step_val_kl.append(np.mean(kl))
                val_tl = np.mean(step_val_tl)
                val_rl = np.mean(step_val_rl)
                val_kl = np.mean(step_val_kl)
                val_total_losses.append(val_tl)
                val_reconstruction_losses.append(val_rl)
                val_kl_losses.append(val_kl)
                val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '

            if (epoch+1) % save_epoch_interval == 0 and run_folder != None:
                self.save(run_folder)
                self.save_weights(run_folder, self.epoch)

            elapsed_time = datetime.datetime.now() - start_time
            print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')

            self.epoch += 1

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)

        dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }
        if not validation_data_flow is None:
            dic['val_loss'] = val_total_losses
            dic['val_reconstruction_loss'] = val_reconstruction_losses
            dic['val_kl_loss'] = val_kl_losses

        return dic


    @staticmethod
    def showImages(imgs1, imgs2, txts, w, h, vskip=0.5, filepath=None):
        n = len(imgs1)
        fig, ax = plt.subplots(2, n, figsize=(w * n, (2+vskip) * h))
        for i in range(n):
            if n == 1:
                axis = ax[0]
            else:
                axis = ax[0][i]
            img = imgs1[i].squeeze()
            axis.imshow(img, cmap='gray_r')
            axis.axis('off')

            axis.text(0.5, -0.35, txts[i], fontsize=10, ha='center', transform=axis.transAxes)

            if n == 1:
                axis = ax[1]
            else:
                axis = ax[1][i]
            img2 = imgs2[i].squeeze()
            axis.imshow(img2, cmap='gray_r')
            axis.axis('off')

        if not filepath is None:
            dpath, fname = os.path.split(filepath)
            if dpath != '' and not os.path.exists(dpath):
                os.makedirs(dpath)
            fig.savefig(filepath, dpi=600)
            plt.close()
        else:
            plt.show()

    @staticmethod
    def plot_history(vals, labels):
        colors = ['red', 'blue', 'green', 'orange', 'black', 'pink']
        n = len(vals)
        fig, ax = plt.subplots(1, 1, figsize=(9,4))
        for i in range(n):
            ax.plot(vals[i], c=colors[i], label=labels[i])
        ax.legend(loc='upper right')
        ax.set_xlabel('epochs')
        # ax[0].set_ylabel('loss')
        
        plt.show()
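The kl_loss term in loss_fn above is the closed-form KL divergence between the encoder's Gaussian N(mu, sigma^2) and the standard normal prior. As a quick sanity check (a standalone NumPy sketch, independent of the class above), the per-dimension term -0.5 * (1 + log_var - mu^2 - exp(log_var)) can be compared against the textbook formula 0.5 * (sigma^2 + mu^2 - 1 - ln sigma^2):

```python
import numpy as np

def kl_term(mu, log_var):
    # Per-dimension KL term as written in VAEModel.loss_fn
    # (the sum there is multiplied by -0.5, folded in here).
    return -0.5 * (1.0 + log_var - np.square(mu) - np.exp(log_var))

def kl_closed_form(mu, sigma):
    # KL( N(mu, sigma^2) || N(0, 1) ) in its usual textbook form.
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0 - np.log(sigma ** 2))

mu, sigma = 0.7, 1.3
assert np.isclose(kl_term(mu, np.log(sigma ** 2)), kl_closed_form(mu, sigma))
assert np.isclose(kl_term(0.0, 0.0), 0.0)   # KL is zero when the posterior equals the prior
```

The two expressions are algebraically identical, so the KL term vanishes exactly when mu = 0 and log_var = 0, i.e. when the encoder output matches the prior.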

Preparing CelebA dataset

Official WWW of CelebA dataset: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

Google Drive of CelebA dataset: https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg

img_align_celeba.zip mirrored on my Google Drive:
https://drive.google.com/uc?id=1LFKeoI-hb96jlV0K10dO1o04iQPBoFdx

In [9]:
# Download img_align_celeba.zip from GoogleDrive

MIRRORED_URL = 'https://drive.google.com/uc?id=1LFKeoI-hb96jlV0K10dO1o04iQPBoFdx'

! gdown {MIRRORED_URL}
Downloading...
From: https://drive.google.com/uc?id=1LFKeoI-hb96jlV0K10dO1o04iQPBoFdx
To: /content/img_align_celeba.zip
100% 1.44G/1.44G [00:06<00:00, 222MB/s]
In [72]:
! ls -l img_align_celeba.zip
-rw-r--r-- 1 root root 1443490838 Nov 21 13:56 img_align_celeba.zip
In [11]:
DATA_DIR = 'data'
DATA_SUBDIR = 'img_align_celeba'
In [12]:
! rm -rf {DATA_DIR}
! unzip -d {DATA_DIR} -q {DATA_SUBDIR}.zip
In [13]:
! ls -l {DATA_DIR}/{DATA_SUBDIR} | head
! ls {DATA_DIR}/{DATA_SUBDIR} | wc
total 1737936
-rw-r--r-- 1 root root 11440 Sep 28  2015 000001.jpg
-rw-r--r-- 1 root root  7448 Sep 28  2015 000002.jpg
-rw-r--r-- 1 root root  4253 Sep 28  2015 000003.jpg
-rw-r--r-- 1 root root 10747 Sep 28  2015 000004.jpg
-rw-r--r-- 1 root root  6351 Sep 28  2015 000005.jpg
-rw-r--r-- 1 root root  8073 Sep 28  2015 000006.jpg
-rw-r--r-- 1 root root  8203 Sep 28  2015 000007.jpg
-rw-r--r-- 1 root root  7725 Sep 28  2015 000008.jpg
-rw-r--r-- 1 root root  8641 Sep 28  2015 000009.jpg
 202599  202599 2228589

Check the CelebA dataset

In [14]:
# paths to all the image files.

import os
import glob
import numpy as np

all_file_paths = np.array(glob.glob(os.path.join(DATA_DIR, DATA_SUBDIR, '*.jpg')))
n_all_images = len(all_file_paths)

print(n_all_images)
202599
In [15]:
# select some image files.

n_to_show = 10
selected_indices = np.random.choice(n_all_images, n_to_show, replace=False)
selected_paths = all_file_paths[selected_indices]
In [16]:
# Display some images.
%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, n_to_show, figsize=(1.4 * n_to_show, 1.4))
for i, path in enumerate(selected_paths):
    img = tf.keras.preprocessing.image.load_img(path)
    ax[i].imshow(img)
    ax[i].axis('off')
plt.show()

Separate image files for train and test

In [17]:
TRAIN_DATA_DIR = 'train_data'
TEST_DATA_DIR = 'test_data'
In [18]:
import os

split = 0.05

indices = np.arange(n_all_images)
np.random.shuffle(indices)
train_indices = indices[: -int(n_all_images * split)]
test_indices = indices[-int(n_all_images * split):]

! rm -rf {TRAIN_DATA_DIR} {TEST_DATA_DIR}

dst=f'{TRAIN_DATA_DIR}/celeba'
if not os.path.exists(dst):
    os.makedirs(dst)
for idx in train_indices:
    path = all_file_paths[idx]
    dpath, fname = os.path.split(path)
    os.symlink(f'../../{path}', f'{dst}/{fname}')

dst=f'{TEST_DATA_DIR}/celeba'
if not os.path.exists(dst):
    os.makedirs(dst)
for idx in test_indices:
    path = all_file_paths[idx]
    dpath, fname = os.path.split(path)
    os.symlink(f'../../{path}', f'{dst}/{fname}')
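As a quick check on the split arithmetic (a standalone sketch using the image count printed above), the 5% test split works out to the exact counts that flow_from_directory reports later in this notebook:

```python
n_all_images = 202599   # total CelebA images, from the cell above
split = 0.05

# Same arithmetic as the slicing of `indices` above:
# the last int(n * split) shuffled indices become the test set.
n_test = int(n_all_images * split)
n_train = n_all_images - n_test
print(n_train, n_test)   # → 192470 10129
```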

Prepare ImageDataGenerator

flow_from_directory() must be given the parent directory of the directory that contains the image files, not the image directory itself.
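The expected layout can be illustrated without TensorFlow (a sketch with hypothetical file names): flow_from_directory() scans the directory it is given for class subdirectories (here the single class celeba) and collects the images one level below.

```python
import os
import tempfile

# Build the layout that flow_from_directory() expects:
#   train_data/            <- pass THIS directory to flow_from_directory()
#       celeba/            <- one subdirectory per class
#           000001.jpg ... <- image files live one level below
root = tempfile.mkdtemp()
train_dir = os.path.join(root, 'train_data')
class_dir = os.path.join(train_dir, 'celeba')
os.makedirs(class_dir)
for name in ('000001.jpg', '000002.jpg'):
    open(os.path.join(class_dir, name), 'w').close()

# For this layout, flow_from_directory(train_dir) would report:
#   "Found 2 images belonging to 1 classes."
classes = sorted(os.listdir(train_dir))
n_images = sum(len(files) for _, _, files in os.walk(train_dir))
print(classes, n_images)   # → ['celeba'] 2
```

This is why the split cell above creates `train_data/celeba/` and `test_data/celeba/` rather than placing the symlinks directly under `train_data/`.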

In [19]:
INPUT_DIM = (128, 128, 3)
BATCH_SIZE = 32
In [20]:
data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale = 1.0/255
    )

data_flow = data_gen.flow_from_directory(
    TRAIN_DATA_DIR,
    target_size = INPUT_DIM[:2],
    batch_size = BATCH_SIZE,
    shuffle=True,
    class_mode = 'input'
    )

val_data_flow = data_gen.flow_from_directory(
    TEST_DATA_DIR,
    target_size = INPUT_DIM[:2],
    batch_size = BATCH_SIZE,
    shuffle=True,
    class_mode = 'input'
    )
Found 192470 images belonging to 1 classes.
Found 10129 images belonging to 1 classes.
In [21]:
print(len(data_flow))
print(len(val_data_flow))
6015
317
In [22]:
# ImageDataGenerator.next() returns the same x and y when class_mode='input'
x, y = next(data_flow)
print(x[0].shape)

%matplotlib inline
import matplotlib.pyplot as plt

n_to_show = 10
fig, ax = plt.subplots(2, n_to_show, figsize=(1.4 * n_to_show, 1.4 * 2))
for i in range(n_to_show):
    ax[0][i].imshow(x[i])
    ax[0][i].axis('off')
    ax[1][i].imshow(y[i])
    ax[1][i].axis('off')
plt.show()
(128, 128, 3)

Define the Neural Network Model

In [23]:
from nw.VariationalAutoEncoder import VariationalAutoEncoder

vae = VariationalAutoEncoder(
    input_dim = INPUT_DIM,
    encoder_conv_filters = [ 32, 64, 64, 64 ],
    encoder_conv_kernel_size = [ 3, 3, 3, 3 ],
    encoder_conv_strides = [ 2, 2, 2, 2 ],
    decoder_conv_t_filters = [ 64, 64, 32, 3 ],
    decoder_conv_t_kernel_size = [ 3, 3, 3, 3 ],
    decoder_conv_t_strides = [ 2, 2, 2, 2 ],
    z_dim = 200,
    use_batch_norm = True,
    use_dropout = True,
    r_loss_factor = 10000
)
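With padding='same', each stride-2 Conv2D halves the spatial size (ceil(n / stride)), so the four encoder convolutions take the 128x128 input down to 8x8, which the summary below confirms. A quick sketch of that arithmetic:

```python
import math

def same_conv_out(n, stride):
    # Output spatial size of a Conv2D with padding='same'.
    return math.ceil(n / stride)

size = 128                       # INPUT_DIM spatial size
sizes = [size]
for stride in [2, 2, 2, 2]:      # encoder_conv_strides
    size = same_conv_out(size, stride)
    sizes.append(size)
print(sizes)   # → [128, 64, 32, 16, 8]
```

The decoder's four stride-2 Conv2DTranspose layers reverse this, growing 8 → 16 → 32 → 64 → 128 back to the input resolution.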
In [24]:
vae.encoder.summary()
Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 encoder_input (InputLayer)     [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 encoder_conv_0 (Conv2D)        (None, 64, 64, 32)   896         ['encoder_input[0][0]']          
                                                                                                  
 batch_normalization (BatchNorm  (None, 64, 64, 32)  128         ['encoder_conv_0[0][0]']         
 alization)                                                                                       
                                                                                                  
 leaky_re_lu (LeakyReLU)        (None, 64, 64, 32)   0           ['batch_normalization[0][0]']    
                                                                                                  
 dropout (Dropout)              (None, 64, 64, 32)   0           ['leaky_re_lu[0][0]']            
                                                                                                  
 encoder_conv_1 (Conv2D)        (None, 32, 32, 64)   18496       ['dropout[0][0]']                
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 32, 32, 64)  256         ['encoder_conv_1[0][0]']         
 rmalization)                                                                                     
                                                                                                  
 leaky_re_lu_1 (LeakyReLU)      (None, 32, 32, 64)   0           ['batch_normalization_1[0][0]']  
                                                                                                  
 dropout_1 (Dropout)            (None, 32, 32, 64)   0           ['leaky_re_lu_1[0][0]']          
                                                                                                  
 encoder_conv_2 (Conv2D)        (None, 16, 16, 64)   36928       ['dropout_1[0][0]']              
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 16, 16, 64)  256         ['encoder_conv_2[0][0]']         
 rmalization)                                                                                     
                                                                                                  
 leaky_re_lu_2 (LeakyReLU)      (None, 16, 16, 64)   0           ['batch_normalization_2[0][0]']  
                                                                                                  
 dropout_2 (Dropout)            (None, 16, 16, 64)   0           ['leaky_re_lu_2[0][0]']          
                                                                                                  
 encoder_conv_3 (Conv2D)        (None, 8, 8, 64)     36928       ['dropout_2[0][0]']              
                                                                                                  
 batch_normalization_3 (BatchNo  (None, 8, 8, 64)    256         ['encoder_conv_3[0][0]']         
 rmalization)                                                                                     
                                                                                                  
 leaky_re_lu_3 (LeakyReLU)      (None, 8, 8, 64)     0           ['batch_normalization_3[0][0]']  
                                                                                                  
 dropout_3 (Dropout)            (None, 8, 8, 64)     0           ['leaky_re_lu_3[0][0]']          
                                                                                                  
 flatten (Flatten)              (None, 4096)         0           ['dropout_3[0][0]']              
                                                                                                  
 mu (Dense)                     (None, 200)          819400      ['flatten[0][0]']                
                                                                                                  
 log_var (Dense)                (None, 200)          819400      ['flatten[0][0]']                
                                                                                                  
 encoder_output (Sampling)      (None, 200)          0           ['mu[0][0]',                     
                                                                  'log_var[0][0]']                
                                                                                                  
==================================================================================================
Total params: 1,732,944
Trainable params: 1,732,496
Non-trainable params: 448
__________________________________________________________________________________________________
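The mu and log_var heads above parameterize a diagonal Gaussian over the 200-dimensional latent space, and the encoder_output (Sampling) layer draws z from it with the reparameterization trick. A minimal NumPy sketch of what such a Sampling layer computes (the actual layer lives in nw.VariationalAutoEncoder, which is not shown here):

```python
import numpy as np

def sample_z(mu, log_var, rng):
    # Reparameterization trick: z = mu + sigma * eps, with sigma = exp(log_var / 2).
    # eps is standard normal noise, so gradients can flow through mu and log_var.
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

rng = np.random.default_rng(0)
mu = np.zeros((1, 200))       # batch of one, z_dim = 200
log_var = np.zeros((1, 200))  # log_var = 0 -> sigma = 1
z = sample_z(mu, log_var, rng)
print(z.shape)  # (1, 200)
```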
In [25]:
vae.decoder.summary()
Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 decoder_input (InputLayer)  [(None, 200)]             0         
                                                                 
 dense (Dense)               (None, 4096)              823296    
                                                                 
 reshape (Reshape)           (None, 8, 8, 64)          0         
                                                                 
 decoder_conv_t_0 (Conv2DTra  (None, 16, 16, 64)       36928     
 nspose)                                                         
                                                                 
 batch_normalization_4 (Batc  (None, 16, 16, 64)       256       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 16, 16, 64)        0         
                                                                 
 dropout_4 (Dropout)         (None, 16, 16, 64)        0         
                                                                 
 decoder_conv_t_1 (Conv2DTra  (None, 32, 32, 64)       36928     
 nspose)                                                         
                                                                 
 batch_normalization_5 (Batc  (None, 32, 32, 64)       256       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 32, 32, 64)        0         
                                                                 
 dropout_5 (Dropout)         (None, 32, 32, 64)        0         
                                                                 
 decoder_conv_t_2 (Conv2DTra  (None, 64, 64, 32)       18464     
 nspose)                                                         
                                                                 
 batch_normalization_6 (Batc  (None, 64, 64, 32)       128       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_6 (LeakyReLU)   (None, 64, 64, 32)        0         
                                                                 
 dropout_6 (Dropout)         (None, 64, 64, 32)        0         
                                                                 
 decoder_conv_t_3 (Conv2DTra  (None, 128, 128, 3)      867       
 nspose)                                                         
                                                                 
 activation (Activation)     (None, 128, 128, 3)       0         
                                                                 
=================================================================
Total params: 917,123
Trainable params: 916,803
Non-trainable params: 320
_________________________________________________________________
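The decoder mirrors the encoder: each stride-2 Conv2DTranspose doubles the spatial resolution, taking the 8×8×64 tensor from reshape back up to 128×128, and the final layer reduces the channels to 3. A quick sanity check of the shape progression (assuming 'same' padding, so output size = input size × stride):

```python
# Spatial size after a stride-2 Conv2DTranspose with 'same' padding
# is input_size * stride.
size = 8
strides = [2, 2, 2, 2]
sizes = [size]
for s in strides:
    size *= s
    sizes.append(size)
print(sizes)  # [8, 16, 32, 64, 128]
```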

Training

In [26]:
LEARNING_RATE = 0.0005
In [27]:
MAX_EPOCHS = 4

(1) Training: use vae.model.fit()

Note that no loss function is specified in the call to vae.model.compile(). Because the VAE loss cannot be computed simply from y_true and y_pred, the train_step() function of the VAEModel class, invoked by fit(), computes the loss and gradients and performs the training step. The self.optimizer referenced inside train_step() is the optimizer passed to compile().
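The loss that train_step() computes combines a pixel-wise reconstruction term (scaled by r_loss_factor, which is set to 10000 when the model is built in this notebook) with the closed-form KL divergence between the encoder's Gaussian and the standard normal prior. A hedged NumPy sketch of this computation (the exact reduction in nw.VariationalAutoEncoder may differ):

```python
import numpy as np

def vae_loss(x, x_recon, mu, log_var, r_loss_factor=10000):
    # Reconstruction term: mean squared error over pixels, scaled by
    # r_loss_factor so it is not dominated by the KL term.
    reconstruction_loss = r_loss_factor * np.mean((x - x_recon) ** 2, axis=(1, 2, 3))
    # KL(N(mu, sigma^2) || N(0, 1)), summed over the latent dimensions.
    kl_loss = -0.5 * np.sum(1 + log_var - mu ** 2 - np.exp(log_var), axis=1)
    return np.mean(reconstruction_loss + kl_loss)

x = np.zeros((2, 128, 128, 3))
loss = vae_loss(x, x, np.zeros((2, 200)), np.zeros((2, 200)))
print(loss)  # 0.0 -- perfect reconstruction, encoder exactly at the prior
```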

In [28]:
save_path1 = '/content/drive/MyDrive/ColabRun/VAE_CelebA01'
In [29]:
optimizer = tf.keras.optimizers.Adam(learning_rate = LEARNING_RATE)
vae.model.compile(optimizer=optimizer)
In [30]:
history = vae.train_generator_with_fit(
    data_flow,
    epochs = 3,
    run_folder = save_path1
)
Epoch 1/3
6015/6015 [==============================] - 292s 48ms/step - loss: 243.5313 - reconstruction_loss: 183.9535 - kl_loss: 59.5779
Epoch 2/3
6015/6015 [==============================] - 263s 44ms/step - loss: 209.1433 - reconstruction_loss: 147.8302 - kl_loss: 61.3132
Epoch 3/3
6015/6015 [==============================] - 262s 44ms/step - loss: 204.8720 - reconstruction_loss: 143.5817 - kl_loss: 61.2904
In [31]:
print(history.history)
{'loss': [208.21795654296875, 212.05111694335938, 214.60659790039062], 'reconstruction_loss': [148.90037536621094, 149.67774963378906, 153.86358642578125], 'kl_loss': [59.31757736206055, 62.37335968017578, 60.74300765991211]}
In [32]:
print(history.history.keys())
dict_keys(['loss', 'reconstruction_loss', 'kl_loss'])
In [33]:
loss1_1 = history.history['loss']
rloss1_1 = history.history['reconstruction_loss']
kloss1_1 = history.history['kl_loss']

Additional training

Load the saved parameters and model weights, and continue training.

In [34]:
# Load the saved parameters and weights.

vae_work = VariationalAutoEncoder.load(save_path1)

# Display the epoch count of the model.

print(vae_work.epoch)
3
In [35]:
# Continue training.

vae_work.model.compile(optimizer)


history2 = vae_work.train_generator_with_fit(
    data_flow,
    epochs = MAX_EPOCHS,
    run_folder = save_path1
)
Epoch 4/4
6015/6015 [==============================] - 265s 44ms/step - loss: 203.5751 - reconstruction_loss: 142.3118 - kl_loss: 61.2633
In [36]:
print(len(history2.history))
3
In [37]:
loss1_2 = history2.history['loss']
rloss1_2 = history2.history['reconstruction_loss']
kloss1_2 = history2.history['kl_loss']

loss1 = np.concatenate([loss1_1, loss1_2], axis=0)
rloss1 = np.concatenate([rloss1_1, rloss1_2], axis=0)
kloss1 = np.concatenate([kloss1_1, kloss1_2], axis=0)
In [38]:
VariationalAutoEncoder.plot_history([loss1, rloss1, kloss1], ['total_loss', 'reconstruct_loss', 'kl_loss'])

Validate the training results

Because vae.decoder() is declared with <code>@tf.function</code>, its return value is a Tensor and must be converted to a NumPy array.

In [39]:
x_, _ = next(val_data_flow)
selected_images = x_[:10]
In [40]:
z_mean, z_log_var, z = vae_work.encoder(selected_images)
reconst_images = vae_work.decoder(z).numpy()  # Convert Tensor to numpy array.

txts = [f'{p[0]:.3f}, {p[1]:.3f}' for p in z]  # label each image with its first two latent components
In [41]:
%matplotlib inline

VariationalAutoEncoder.showImages(selected_images, reconst_images, txts, 1.4, 1.4)

(2) Training with tf.GradientTape()

Instead of using fit(), a hand-written train() function computes the loss, obtains the gradients, and applies them to the variables.

The train_tf() function is sped up by declaring the compute_loss_and_grads() function, which computes the loss and gradients, with <code>@tf.function</code>.
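The pattern behind train_tf() can be illustrated on a toy problem: a graph-compiled function records the loss under tf.GradientTape(), returns the gradients, and the optimizer applies them. This is a minimal sketch of the pattern, not the actual nw.VariationalAutoEncoder code; the toy quadratic loss stands in for the VAE loss.

```python
import tensorflow as tf

w = tf.Variable(0.0)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function  # graph-compile the loss/gradient computation, as train_tf() does
def compute_loss_and_grads():
    with tf.GradientTape() as tape:
        loss = (w - 3.0) ** 2  # toy loss standing in for the VAE loss
    grads = tape.gradient(loss, [w])
    return loss, grads

for _ in range(100):
    loss, grads = compute_loss_and_grads()
    optimizer.apply_gradients(zip(grads, [w]))

print(float(w))  # converges toward the minimum at w = 3
```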

In [42]:
save_path2 = '/content/drive/MyDrive/ColabRun/VAE_CelebA02/'
In [43]:
from nw.VariationalAutoEncoder import VariationalAutoEncoder

vae2 = VariationalAutoEncoder(
    input_dim = INPUT_DIM,
    encoder_conv_filters = [ 32, 64, 64, 64 ],
    encoder_conv_kernel_size = [ 3, 3, 3, 3 ],
    encoder_conv_strides = [ 2, 2, 2, 2 ],
    decoder_conv_t_filters = [ 64, 64, 32, 3 ],
    decoder_conv_t_kernel_size = [ 3, 3, 3, 3 ],
    decoder_conv_t_strides = [ 2, 2, 2, 2 ],
    z_dim = 200,
    use_batch_norm = True,
    use_dropout = True,
    r_loss_factor = 10000
)
In [44]:
optimizer2 = tf.keras.optimizers.Adam(learning_rate = LEARNING_RATE)
In [45]:
log2_1 = vae2.train_tf_generator(
    data_flow,
    epochs = 3,
    run_folder = save_path2,
    optimizer = optimizer2,
    save_epoch_interval = 50
)
1/3 6015 loss: total 244.037 reconstruction 183.636 kl 60.401 0:04:37.396218
2/3 6015 loss: total 209.520 reconstruction 148.031 kl 61.489 0:09:14.412426
3/3 6015 loss: total 205.202 reconstruction 144.017 kl 61.185 0:13:39.886454
In [46]:
print(log2_1.keys())

loss2_1 = log2_1['loss']
rloss2_1 = log2_1['reconstruction_loss']
kloss2_1 = log2_1['kl_loss']
dict_keys(['loss', 'reconstruction_loss', 'kl_loss'])
In [47]:
# Load the saved parameters and weights.

vae2_work = VariationalAutoEncoder.load(save_path2)
print(vae2_work.epoch)
3
In [48]:
# Continue training.

log2_2 = vae2_work.train_tf_generator(
    data_flow,
    epochs = MAX_EPOCHS,
    run_folder = save_path2,
    optimizer = optimizer2,
    save_epoch_interval=50
)
4/4 6015 loss: total 203.840 reconstruction 142.770 kl 61.070 0:04:23.347881
In [49]:
loss2_2 = log2_2['loss']
rloss2_2 = log2_2['reconstruction_loss']
kloss2_2 = log2_2['kl_loss']
In [50]:
loss2 = np.concatenate([loss2_1, loss2_2], axis=0)
rloss2 = np.concatenate([rloss2_1, rloss2_2], axis=0)
kloss2 = np.concatenate([kloss2_1, kloss2_2], axis=0)
In [51]:
VariationalAutoEncoder.plot_history(
    [loss2], 
    ['total_loss']
)
In [52]:
VariationalAutoEncoder.plot_history(
    [rloss2], 
    ['reconstruction_loss']
)
In [53]:
VariationalAutoEncoder.plot_history(
    [kloss2], 
    ['kl_loss']
)