Updated 21/Nov/2021 by Yoshihisa Nitta  

Further Training of a Variational Auto Encoder on the CelebA Dataset with TensorFlow 2 on Google Colab

Train the Variational Auto Encoder further on the CelebA dataset. This notebook assumes the state left after executing VAE_CelebA_Train.ipynb.

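In outline, the additional training below proceeds as in the following minimal sketch (the folder path, learning rate, and epoch count are the values used later in this notebook; data_flow and val_data_flow are the ImageDataGenerator flows prepared below):

from nw.VariationalAutoEncoder import VariationalAutoEncoder
import tensorflow as tf

# Reload the model saved by VAE_CelebA_Train.ipynb; vae.epoch is restored as well.
vae = VariationalAutoEncoder.load('/content/drive/MyDrive/ColabRun/VAE_CelebA03/')

# Continue training with a fresh optimizer; training resumes from vae.epoch and
# runs up to the absolute epoch number given by 'epochs'.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
log = vae.train_tf_generator(
    data_flow,
    epochs = 200,
    run_folder = '/content/drive/MyDrive/ColabRun/VAE_CelebA03/',
    optimizer = optimizer,
    save_epoch_interval = 50,
    validation_data_flow = val_data_flow
)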

In [1]:
#! pip install tensorflow==2.7.0
In [2]:
%tensorflow_version 2.x

import tensorflow as tf
print(tf.__version__)
2.7.0

Check the Google Colab runtime environment


In [3]:
! nvidia-smi
! cat /proc/cpuinfo
! cat /etc/issue
! free -h
Sun Nov 21 14:45:15 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2199.998
cache size	: 56320 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4399.99
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

processor	: 1
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2199.998
cache size	: 56320 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 1
initial apicid	: 1
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4399.99
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

Ubuntu 18.04.5 LTS \n \l

              total        used        free      shared  buff/cache   available
Mem:            12G        733M          9G        1.2M        2.0G         11G
Swap:            0B          0B          0B

Mount Google Drive from Google Colab


In [4]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
In [5]:
! ls /content/drive
MyDrive  Shareddrives

Download source file from Google Drive or nw.tsuda.ac.jp

Normally, download from Google Drive with gdown. Download from nw.tsuda.ac.jp only if the Google Drive specifications change and the file can no longer be downloaded from Google Drive.


In [9]:
# Download source file
nw_path = './nw'
! rm -rf {nw_path}
! mkdir -p {nw_path}

if True:   # from Google Drive
    url_model =  'https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3'
    ! (cd {nw_path}; gdown {url_model})
else:      # from nw.tsuda.ac.jp
    URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'
    url_model = f'{URL_NW}/models/VariationalAutoEncoder.py'
    ! wget -nd {url_model} -P {nw_path}
Downloading...
From: https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3
To: /content/nw/VariationalAutoEncoder.py
100% 18.7k/18.7k [00:00<00:00, 16.3MB/s]
In [10]:
! cat {nw_path}/VariationalAutoEncoder.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import os
import pickle
import datetime

class Sampling(tf.keras.layers.Layer):
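    # Reparameterization trick: z = mu + exp(log_var / 2) * epsilon with epsilon ~ N(0, 1),
    # so the sampling step stays differentiable with respect to mu and log_var.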
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        mu, log_var = inputs
        epsilon = tf.keras.backend.random_normal(shape=tf.keras.backend.shape(mu), mean=0., stddev=1.)
        return mu + tf.keras.backend.exp(log_var / 2) * epsilon


class VAEModel(tf.keras.models.Model):
    def __init__(self, encoder, decoder, r_loss_factor, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.r_loss_factor = r_loss_factor


    @tf.function
    def loss_fn(self, x):
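        # Per-sample loss:
        #   reconstruction_loss = r_loss_factor * mean over (H, W, C) of (x - reconstruction)^2
        #   kl_loss              = -0.5 * sum over z_dim of (1 + log_var - mu^2 - exp(log_var))
        #   total_loss           = reconstruction_loss + kl_loss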
        z_mean, z_log_var, z = self.encoder(x)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            tf.square(x - reconstruction), axis=[1,2,3]
        ) * self.r_loss_factor
        kl_loss = tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
            axis = 1
        ) * (-0.5)
        total_loss = reconstruction_loss + kl_loss
        return total_loss, reconstruction_loss, kl_loss


    @tf.function
    def compute_loss_and_grads(self, x):
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self.loss_fn(x)
        grads = tape.gradient(total_loss, self.trainable_weights)
        return total_loss, reconstruction_loss, kl_loss, grads


    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        total_loss, reconstruction_loss, kl_loss, grads = self.compute_loss_and_grads(data)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": tf.math.reduce_mean(total_loss),
            "reconstruction_loss": tf.math.reduce_mean(reconstruction_loss),
            "kl_loss": tf.math.reduce_mean(kl_loss),
        }

    def call(self,inputs):
        _, _, z = self.encoder(inputs)
        return self.decoder(z)


class VariationalAutoEncoder():
    def __init__(self, 
                 input_dim,
                 encoder_conv_filters,
                 encoder_conv_kernel_size,
                 encoder_conv_strides,
                 decoder_conv_t_filters,
                 decoder_conv_t_kernel_size,
                 decoder_conv_t_strides,
                 z_dim,
                 r_loss_factor,   ### added
                 use_batch_norm = False,
                 use_dropout = False,
                 epoch = 0
                ):
        self.name = 'variational_autoencoder'
        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim
        self.r_loss_factor = r_loss_factor   ### added
            
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout

        self.epoch = epoch
            
        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)
            
        self._build()
 

    def _build(self):
        ### THE ENCODER
        encoder_input = tf.keras.layers.Input(shape=self.input_dim, name='encoder_input')
        x = encoder_input
        
        for i in range(self.n_layers_encoder):
            x = conv_layer = tf.keras.layers.Conv2D(
                filters = self.encoder_conv_filters[i],
                kernel_size = self.encoder_conv_kernel_size[i],
                strides = self.encoder_conv_strides[i],
                padding  = 'same',
                name = 'encoder_conv_' + str(i)
            )(x)

            if self.use_batch_norm:                                ### The order of layers is opposite to AutoEncoder
                x = tf.keras.layers.BatchNormalization()(x)        ###   AE: LeakyReLU -> BatchNorm
            x = tf.keras.layers.LeakyReLU()(x)                     ###   VAE: BatchNorm -> LeakyReLU
            
            if self.use_dropout:
                x = tf.keras.layers.Dropout(rate = 0.25)(x)
        
        shape_before_flattening = tf.keras.backend.int_shape(x)[1:]
        
        x = tf.keras.layers.Flatten()(x)
        
        self.mu = tf.keras.layers.Dense(self.z_dim, name='mu')(x)
        self.log_var = tf.keras.layers.Dense(self.z_dim, name='log_var')(x) 
        self.z = Sampling(name='encoder_output')([self.mu, self.log_var])
        
        self.encoder = tf.keras.models.Model(encoder_input, [self.mu, self.log_var, self.z], name='encoder')
        
        
        ### THE DECODER
        decoder_input = tf.keras.layers.Input(shape=(self.z_dim,), name='decoder_input')
        x = decoder_input
        x = tf.keras.layers.Dense(np.prod(shape_before_flattening))(x)
        x = tf.keras.layers.Reshape(shape_before_flattening)(x)
        
        for i in range(self.n_layers_decoder):
            x = conv_t_layer =   tf.keras.layers.Conv2DTranspose(
                filters = self.decoder_conv_t_filters[i],
                kernel_size = self.decoder_conv_t_kernel_size[i],
                strides = self.decoder_conv_t_strides[i],
                padding = 'same',
                name = 'decoder_conv_t_' + str(i)
            )(x)
            
            if i < self.n_layers_decoder - 1:
                if self.use_batch_norm:                           ### The order of layers is opposite to AutoEncoder
                    x = tf.keras.layers.BatchNormalization()(x)   ###     AE: LeakyReLU -> BatchNorm
                x = tf.keras.layers.LeakyReLU()(x)                ###      VAE: BatchNorm -> LeakyReLU                
                if self.use_dropout:
                    x = tf.keras.layers.Dropout(rate=0.25)(x)
            else:
                x = tf.keras.layers.Activation('sigmoid')(x)
       
        decoder_output = x
        self.decoder = tf.keras.models.Model(decoder_input, decoder_output, name='decoder')  ### added (name)
        
        ### THE FULL AUTOENCODER
        self.model = VAEModel(self.encoder, self.decoder, self.r_loss_factor)
        
        
    def save(self, folder):
        self.save_params(os.path.join(folder, 'params.pkl'))
        self.save_weights(folder)


    @staticmethod
    def load(folder, epoch=None):  # VariationalAutoEncoder.load(folder)
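        # The parameter list is unpickled in the same order save_params() wrote it,
        # so it can be passed straight to the constructor.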
        params = VariationalAutoEncoder.load_params(os.path.join(folder, 'params.pkl'))
        VAE = VariationalAutoEncoder(*params)
        if epoch is None:
            VAE.load_weights(folder)
        else:
            VAE.load_weights(folder, epoch-1)
            VAE.epoch = epoch
        return VAE

        
    def save_params(self, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        with open(filepath, 'wb') as f:
            pickle.dump([
                self.input_dim,
                self.encoder_conv_filters,
                self.encoder_conv_kernel_size,
                self.encoder_conv_strides,
                self.decoder_conv_t_filters,
                self.decoder_conv_t_kernel_size,
                self.decoder_conv_t_strides,
                self.z_dim,
                self.r_loss_factor,
                self.use_batch_norm,
                self.use_dropout,
                self.epoch
            ], f)


    @staticmethod
    def load_params(filepath):
        with open(filepath, 'rb') as f:
            params = pickle.load(f)
        return params


    def save_weights(self, folder, epoch=None):
        if epoch is None:
            self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights.h5'))
            self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights.h5'))
        else:
            self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))
            self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))


    def save_model_weights(self, model, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        model.save_weights(filepath)


    def load_weights(self, folder, epoch=None):
        if epoch is None:
            self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights.h5'))
            self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights.h5'))
        else:
            self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))
            self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))


    def save_images(self, imgs, filepath):
        z_mean, z_log_var, z = self.encoder.predict(imgs)
        reconst_imgs = self.decoder.predict(z)
        txts = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z ]
        VariationalAutoEncoder.showImages(imgs, reconst_imgs, txts, 1.4, 1.4, 0.5, filepath)
      

    def compile(self, learning_rate):
        self.learning_rate = learning_rate
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.model.compile(optimizer=optimizer)     # CAUTION!!!: loss(y_true, y_pred) function is not specified.
        
        
    def train_with_fit(
            self,
            x_train,
            batch_size,
            epochs,
            run_folder='run/'
    ):
        history = self.model.fit(
            x_train,
            x_train,
            batch_size = batch_size,
            shuffle=True,
            initial_epoch = self.epoch,
            epochs = epochs
        )
        if (self.epoch < epochs):
            self.epoch = epochs

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)
        
        return history


    def train_generator_with_fit(
            self,
            data_flow,
            epochs,
            run_folder='run/'
    ):
        history = self.model.fit(
            data_flow,
            initial_epoch = self.epoch,
            epochs = epochs
        )
        if (self.epoch < epochs):
            self.epoch = epochs

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)
        
        return history


    def train_tf(
            self,
            x_train,
            batch_size = 32,
            epochs = 10,
            shuffle = False,
            run_folder = 'run/',
            optimizer = None,
            save_epoch_interval = 100,
            validation_data = None
    ):
        start_time = datetime.datetime.now()
        steps = x_train.shape[0] // batch_size

        total_losses = []
        reconstruction_losses = []
        kl_losses = []

        val_total_losses = []
        val_reconstruction_losses = []
        val_kl_losses = []

        for epoch in range(self.epoch, epochs):
            epoch_loss = 0
            indices = tf.range(x_train.shape[0], dtype=tf.int32)
            if shuffle:
                indices = tf.random.shuffle(indices)
            x_ = x_train[indices]

            step_total_losses = []
            step_reconstruction_losses = []
            step_kl_losses = []
            for step in range(steps):
                start = batch_size * step
                end = start + batch_size

                total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x_[start:end])
                optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
                
                step_total_losses.append(np.mean(total_loss))
                step_reconstruction_losses.append(np.mean(reconstruction_loss))
                step_kl_losses.append(np.mean(kl_loss))
            
            epoch_total_loss = np.mean(step_total_losses)
            epoch_reconstruction_loss = np.mean(step_reconstruction_losses)
            epoch_kl_loss = np.mean(step_kl_losses)

            total_losses.append(epoch_total_loss)
            reconstruction_losses.append(epoch_reconstruction_loss)
            kl_losses.append(epoch_kl_loss)

            val_str = ''
            if not validation_data is None:
                x_val = validation_data
                tl, rl, kl = self.model.loss_fn(x_val)
                val_tl = np.mean(tl)
                val_rl = np.mean(rl)
                val_kl = np.mean(kl)
                val_total_losses.append(val_tl)
                val_reconstruction_losses.append(val_rl)
                val_kl_losses.append(val_kl)
                val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '

            if (epoch+1) % save_epoch_interval == 0 and run_folder != None:
                self.save(run_folder)
                self.save_weights(run_folder, self.epoch)

            elapsed_time = datetime.datetime.now() - start_time
            print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')

            self.epoch += 1

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)

        dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }
        if not validation_data is None:
            dic['val_loss'] = val_total_losses
            dic['val_reconstruction_loss'] = val_reconstruction_losses
            dic['val_kl_loss'] = val_kl_losses

        return dic
            

    def train_tf_generator(
            self,
            data_flow,
            epochs = 10,
            run_folder = 'run/',
            optimizer = None,
            save_epoch_interval = 100,
            validation_data_flow = None
    ):
        start_time = datetime.datetime.now()
        steps = len(data_flow)

        total_losses = []
        reconstruction_losses = []
        kl_losses = []

        val_total_losses = []
        val_reconstruction_losses = []
        val_kl_losses = []

        for epoch in range(self.epoch, epochs):
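            # self.epoch holds the number of epochs already completed, so training resumes
            # from there; 'epochs' is the absolute target epoch, not a number of extra epochs.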
            epoch_loss = 0

            step_total_losses = []
            step_reconstruction_losses = []
            step_kl_losses = []

            for step in range(steps):
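                # data_flow yields (images, targets); the targets are ignored here
                # (with class_mode='input' the targets are simply the images themselves).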
                x, _ = next(data_flow)

                total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x)
                optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
                
                step_total_losses.append(np.mean(total_loss))
                step_reconstruction_losses.append(np.mean(reconstruction_loss))
                step_kl_losses.append(np.mean(kl_loss))
            
            epoch_total_loss = np.mean(step_total_losses)
            epoch_reconstruction_loss = np.mean(step_reconstruction_losses)
            epoch_kl_loss = np.mean(step_kl_losses)

            total_losses.append(epoch_total_loss)
            reconstruction_losses.append(epoch_reconstruction_loss)
            kl_losses.append(epoch_kl_loss)

            val_str = ''
            if not validation_data_flow is None:
                step_val_tl = []
                step_val_rl = []
                step_val_kl = []
                for i in range(len(validation_data_flow)):
                    x, _ = next(validation_data_flow)
                    tl, rl, kl = self.model.loss_fn(x)
                    step_val_tl.append(np.mean(tl))
                    step_val_rl.append(np.mean(rl))
                    step_val_kl.append(np.mean(kl))
                val_tl = np.mean(step_val_tl)
                val_rl = np.mean(step_val_rl)
                val_kl = np.mean(step_val_kl)
                val_total_losses.append(val_tl)
                val_reconstruction_losses.append(val_rl)
                val_kl_losses.append(val_kl)
                val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '

            if (epoch+1) % save_epoch_interval == 0 and run_folder != None:
                self.save(run_folder)
                self.save_weights(run_folder, self.epoch)

            elapsed_time = datetime.datetime.now() - start_time
            print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')

            self.epoch += 1

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)

        dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }
        if not validation_data_flow is None:
            dic['val_loss'] = val_total_losses
            dic['val_reconstruction_loss'] = val_reconstruction_losses
            dic['val_kl_loss'] = val_kl_losses

        return dic


    @staticmethod
    def showImages(imgs1, imgs2, txts, w, h, vskip=0.5, filepath=None):
        n = len(imgs1)
        fig, ax = plt.subplots(2, n, figsize=(w * n, (2+vskip) * h))
        for i in range(n):
            if n == 1:
                axis = ax[0]
            else:
                axis = ax[0][i]
            img = imgs1[i].squeeze()
            axis.imshow(img, cmap='gray_r')
            axis.axis('off')

            axis.text(0.5, -0.35, txts[i], fontsize=10, ha='center', transform=axis.transAxes)

            if n == 1:
                axis = ax[1]
            else:
                axis = ax[1][i]
            img2 = imgs2[i].squeeze()
            axis.imshow(img2, cmap='gray_r')
            axis.axis('off')

        if not filepath is None:
            dpath, fname = os.path.split(filepath)
            if dpath != '' and not os.path.exists(dpath):
                os.makedirs(dpath)
            fig.savefig(filepath, dpi=600)
            plt.close()
        else:
            plt.show()

    @staticmethod
    def plot_history(vals, labels):
        colors = ['red', 'blue', 'green', 'orange', 'black', 'pink']
        n = len(vals)
        fig, ax = plt.subplots(1, 1, figsize=(9,4))
        for i in range(n):
            ax.plot(vals[i], c=colors[i], label=labels[i])
        ax.legend(loc='upper right')
        ax.set_xlabel('epochs')
        # ax[0].set_ylabel('loss')
        
        plt.show()

Prepare the CelebA dataset

Official website of the CelebA dataset: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

Google Drive folder of the CelebA dataset: https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg

img_align_celeba.zip mirrored on my Google Drive:
https://drive.google.com/uc?id=1LFKeoI-hb96jlV0K10dO1o04iQPBoFdx


In [11]:
# Download img_align_celeba.zip from GoogleDrive

MIRRORED_URL = 'https://drive.google.com/uc?id=1LFKeoI-hb96jlV0K10dO1o04iQPBoFdx'

! gdown {MIRRORED_URL}
Downloading...
From: https://drive.google.com/uc?id=1LFKeoI-hb96jlV0K10dO1o04iQPBoFdx
To: /content/img_align_celeba.zip
100% 1.44G/1.44G [00:06<00:00, 238MB/s]
In [12]:
! ls -l
total 1409676
drwx------ 6 root root       4096 Nov 21 14:46 drive
-rw-r--r-- 1 root root 1443490838 Nov 21 14:49 img_align_celeba.zip
drwxr-xr-x 2 root root       4096 Nov 21 14:48 nw
drwxr-xr-x 1 root root       4096 Nov 18 14:36 sample_data
In [13]:
DATA_DIR = 'data'
DATA_SUBDIR = 'img_align_celeba'
In [14]:
! rm -rf {DATA_DIR}
! unzip -d {DATA_DIR} -q {DATA_SUBDIR}.zip
In [15]:
! ls -l {DATA_DIR}/{DATA_SUBDIR} | head
! ls {DATA_DIR}/{DATA_SUBDIR} | wc
total 1737936
-rw-r--r-- 1 root root 11440 Sep 28  2015 000001.jpg
-rw-r--r-- 1 root root  7448 Sep 28  2015 000002.jpg
-rw-r--r-- 1 root root  4253 Sep 28  2015 000003.jpg
-rw-r--r-- 1 root root 10747 Sep 28  2015 000004.jpg
-rw-r--r-- 1 root root  6351 Sep 28  2015 000005.jpg
-rw-r--r-- 1 root root  8073 Sep 28  2015 000006.jpg
-rw-r--r-- 1 root root  8203 Sep 28  2015 000007.jpg
-rw-r--r-- 1 root root  7725 Sep 28  2015 000008.jpg
-rw-r--r-- 1 root root  8641 Sep 28  2015 000009.jpg
 202599  202599 2228589

Check the CelebA dataset


In [18]:
# paths to all the image files.

import os
import glob
import numpy as np

all_file_paths = np.array(glob.glob(os.path.join(DATA_DIR, DATA_SUBDIR, '*.jpg')))
n_all_images = len(all_file_paths)

print(n_all_images)
202599
In [19]:
# Select some image files.

n_to_show = 10
selected_indices = np.random.choice(range(n_all_images), n_to_show)
selected_paths = all_file_paths[selected_indices]
In [20]:
# Display some images.
%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, n_to_show, figsize=(1.4 * n_to_show, 1.4))
for i, path in enumerate(selected_paths):
    img = tf.keras.preprocessing.image.load_img(path)
    ax[i].imshow(img)
    ax[i].axis('off')
plt.show()

Split the image files into training and test sets


In [16]:
TRAIN_DATA_DIR = 'train_data'
TEST_DATA_DIR = 'test_data'
In [21]:
import os

split = 0.05

indices = np.arange(n_all_images)
np.random.shuffle(indices)
train_indices = indices[: -int(n_all_images * split)]
test_indices = indices[-int(n_all_images * split):]

! rm -rf {TRAIN_DATA_DIR} {TEST_DATA_DIR}

dst=f'{TRAIN_DATA_DIR}/celeba'
if not os.path.exists(dst):
    os.makedirs(dst)
for idx in train_indices:
    path = all_file_paths[idx]
    dpath, fname = os.path.split(path)
    os.symlink(f'../../{path}', f'{dst}/{fname}')

dst=f'{TEST_DATA_DIR}/celeba'
if not os.path.exists(dst):
    os.makedirs(dst)
for idx in test_indices:
    path = all_file_paths[idx]
    dpath, fname = os.path.split(path)
    os.symlink(f'../../{path}', f'{dst}/{fname}')

Prepare ImageDataGenerator

flow_from_directory() requires you to specify the parent directory of the directory that contains the image files.
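The directories created below follow this layout: the directory passed to flow_from_directory() contains a single (dummy) class subdirectory, and the image files live inside it. A sketch of the layout (the file names are examples from the dataset):

train_data/            <- passed to flow_from_directory()
    celeba/            <- single class subdirectory
        000001.jpg
        000002.jpg
        ...
test_data/
    celeba/
        ...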


In [22]:
INPUT_DIM = (128, 128, 3)
BATCH_SIZE = 32
In [23]:
data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale = 1.0/255
    )

data_flow = data_gen.flow_from_directory(
    TRAIN_DATA_DIR,
    target_size = INPUT_DIM[:2],
    batch_size = BATCH_SIZE,
    shuffle=True,
    class_mode = 'input'
    )

val_data_flow = data_gen.flow_from_directory(
    TEST_DATA_DIR,
    target_size = INPUT_DIM[:2],
    batch_size = BATCH_SIZE,
    shuffle=True,
    class_mode = 'input'
    )
Found 192470 images belonging to 1 classes.
Found 10129 images belonging to 1 classes.
In [24]:
print(len(data_flow))
print(len(val_data_flow))
6015
317

Load the neural network model trained previously

Load the model trained by the '(3) Training' method of VAE_CelebA_Train.ipynb.


In [26]:
save_path3 = '/content/drive/MyDrive/ColabRun/VAE_CelebA03/'
In [47]:
# Load the parameters and model weights saved before

from nw.VariationalAutoEncoder import VariationalAutoEncoder

vae3 = VariationalAutoEncoder.load(save_path3)
print(vae3.epoch)
4

Load the loss transition saved by the previous training


In [48]:
import os
import pickle

var_path = f'{save_path3}/loss_{vae3.epoch-1}.pkl'

with open(var_path, 'rb') as f:
    loss3_1, rloss3_1, kloss3_1, val_loss3_1, val_rloss3_1, val_kloss3_1 = pickle.load(f)
In [49]:
print(len(loss3_1))
4

Train further


In [50]:
LEARNING_RATE = 0.0005
In [51]:
# learning rate = initial_learning_rate * decay_rate ^ (step / decay_steps)   (staircase=False by default, so the decay is smooth)

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate = LEARNING_RATE,
    decay_steps = len(data_flow),
    decay_rate=0.96
)

optimizer3 = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
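With decay_steps = len(data_flow) (the number of steps in one epoch) and decay_rate = 0.96, the learning rate shrinks by a factor of 0.96 per epoch of training steps; for example, after 10 additional epochs it is roughly 0.0005 * 0.96^10 ≈ 0.00033.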
In [52]:
log3_2 = vae3.train_tf_generator(
    data_flow,
    epochs = 200,
    run_folder = save_path3,
    optimizer = optimizer3,
    save_epoch_interval = 50,
    validation_data_flow = val_data_flow
)
5/200 6015 loss: total 200.568 reconstruction 139.167 kl 61.402 val loss total 199.716 reconstruction 138.757 kl 60.959 0:04:32.413829
6/200 6015 loss: total 199.201 reconstruction 137.740 kl 61.461 val loss total 198.803 reconstruction 135.566 kl 63.236 0:09:00.840572
7/200 6015 loss: total 198.084 reconstruction 136.619 kl 61.464 val loss total 197.773 reconstruction 138.099 kl 59.674 0:13:28.770028
8/200 6015 loss: total 197.448 reconstruction 135.976 kl 61.471 val loss total 197.090 reconstruction 135.068 kl 62.023 0:17:56.283880
9/200 6015 loss: total 196.674 reconstruction 135.245 kl 61.429 val loss total 196.911 reconstruction 136.293 kl 60.618 0:22:23.191827
10/200 6015 loss: total 196.214 reconstruction 134.815 kl 61.398 val loss total 195.840 reconstruction 134.457 kl 61.383 0:26:52.411447
11/200 6015 loss: total 195.389 reconstruction 134.051 kl 61.338 val loss total 195.385 reconstruction 134.659 kl 60.726 0:31:21.585673
12/200 6015 loss: total 195.514 reconstruction 134.116 kl 61.398 val loss total 195.432 reconstruction 133.572 kl 61.859 0:35:50.644285
13/200 6015 loss: total 194.981 reconstruction 133.609 kl 61.371 val loss total 194.876 reconstruction 134.686 kl 60.190 0:40:18.560140
14/200 6015 loss: total 194.501 reconstruction 133.142 kl 61.360 val loss total 194.673 reconstruction 133.745 kl 60.929 0:44:45.301768
15/200 6015 loss: total 194.389 reconstruction 132.984 kl 61.405 val loss total 194.404 reconstruction 132.935 kl 61.469 0:49:11.734210
16/200 6015 loss: total 194.060 reconstruction 132.721 kl 61.340 val loss total 194.477 reconstruction 132.656 kl 61.821 0:53:38.312727
17/200 6015 loss: total 193.802 reconstruction 132.420 kl 61.381 val loss total 194.061 reconstruction 132.766 kl 61.296 0:58:04.671105
18/200 6015 loss: total 193.661 reconstruction 132.288 kl 61.373 val loss total 193.420 reconstruction 132.419 kl 61.000 1:02:28.818799
19/200 6015 loss: total 193.363 reconstruction 132.005 kl 61.358 val loss total 193.682 reconstruction 131.354 kl 62.328 1:06:54.455887
20/200 6015 loss: total 193.203 reconstruction 131.853 kl 61.350 val loss total 193.382 reconstruction 132.363 kl 61.019 1:11:20.673511
21/200 6015 loss: total 193.032 reconstruction 131.673 kl 61.360 val loss total 193.215 reconstruction 132.211 kl 61.004 1:15:47.300723
22/200 6015 loss: total 192.995 reconstruction 131.634 kl 61.361 val loss total 193.328 reconstruction 131.712 kl 61.617 1:20:13.540286
23/200 6015 loss: total 192.744 reconstruction 131.357 kl 61.387 val loss total 192.852 reconstruction 131.338 kl 61.514 1:24:39.321436
24/200 6015 loss: total 192.598 reconstruction 131.199 kl 61.399 val loss total 192.727 reconstruction 131.283 kl 61.444 1:29:06.260637
25/200 6015 loss: total 192.436 reconstruction 131.113 kl 61.323 val loss total 192.670 reconstruction 131.508 kl 61.162 1:33:34.268558
26/200 6015 loss: total 192.299 reconstruction 130.964 kl 61.335 val loss total 192.654 reconstruction 131.610 kl 61.044 1:38:02.271351
27/200 6015 loss: total 192.279 reconstruction 130.894 kl 61.385 val loss total 192.524 reconstruction 131.013 kl 61.511 1:42:29.732575
28/200 6015 loss: total 192.090 reconstruction 130.729 kl 61.361 val loss total 192.298 reconstruction 131.387 kl 60.911 1:46:56.482835
29/200 6015 loss: total 191.990 reconstruction 130.627 kl 61.363 val loss total 192.038 reconstruction 130.568 kl 61.470 1:51:22.718647
30/200 6015 loss: total 191.945 reconstruction 130.586 kl 61.359 val loss total 191.994 reconstruction 130.083 kl 61.911 1:55:49.025294
31/200 6015 loss: total 191.772 reconstruction 130.404 kl 61.368 val loss total 192.161 reconstruction 131.505 kl 60.656 2:00:16.566865
32/200 6015 loss: total 191.608 reconstruction 130.271 kl 61.337 val loss total 191.926 reconstruction 130.664 kl 61.262 2:04:44.523716
33/200 6015 loss: total 191.654 reconstruction 130.267 kl 61.387 val loss total 191.847 reconstruction 129.996 kl 61.851 2:09:12.634162
34/200 6015 loss: total 191.527 reconstruction 130.181 kl 61.346 val loss total 191.904 reconstruction 131.380 kl 60.524 2:13:38.871422
35/200 6015 loss: total 191.435 reconstruction 130.085 kl 61.350 val loss total 191.677 reconstruction 130.091 kl 61.586 2:18:05.688359
36/200 6015 loss: total 191.301 reconstruction 129.911 kl 61.390 val loss total 191.597 reconstruction 129.989 kl 61.608 2:22:30.800486
37/200 6015 loss: total 191.411 reconstruction 130.028 kl 61.384 val loss total 191.686 reconstruction 130.779 kl 60.907 2:26:55.185918
38/200 6015 loss: total 191.238 reconstruction 129.885 kl 61.353 val loss total 191.571 reconstruction 130.240 kl 61.330 2:31:23.018640
39/200 6015 loss: total 191.177 reconstruction 129.811 kl 61.366 val loss total 191.513 reconstruction 130.113 kl 61.400 2:35:51.144367
40/200 6015 loss: total 191.087 reconstruction 129.738 kl 61.349 val loss total 191.582 reconstruction 129.551 kl 62.031 2:40:18.046345
41/200 6015 loss: total 191.020 reconstruction 129.638 kl 61.382 val loss total 191.349 reconstruction 130.358 kl 60.991 2:44:45.438834
42/200 6015 loss: total 191.090 reconstruction 129.730 kl 61.360 val loss total 191.228 reconstruction 130.266 kl 60.962 2:49:13.116803
43/200 6015 loss: total 190.814 reconstruction 129.448 kl 61.366 val loss total 191.296 reconstruction 129.657 kl 61.639 2:53:40.482109
44/200 6015 loss: total 190.916 reconstruction 129.537 kl 61.378 val loss total 191.235 reconstruction 130.130 kl 61.105 2:58:07.187151
45/200 6015 loss: total 190.848 reconstruction 129.470 kl 61.378 val loss total 191.184 reconstruction 130.094 kl 61.089 3:02:33.587573
46/200 6015 loss: total 190.755 reconstruction 129.387 kl 61.369 val loss total 191.292 reconstruction 129.986 kl 61.306 3:07:00.933575
47/200 6015 loss: total 190.729 reconstruction 129.384 kl 61.345 val loss total 191.018 reconstruction 129.780 kl 61.238 3:11:28.767333
48/200 6015 loss: total 190.746 reconstruction 129.366 kl 61.379 val loss total 191.128 reconstruction 129.486 kl 61.642 3:15:56.930912
49/200 6015 loss: total 190.665 reconstruction 129.285 kl 61.380 val loss total 191.122 reconstruction 129.899 kl 61.223 3:20:25.891575
50/200 6015 loss: total 190.568 reconstruction 129.216 kl 61.353 val loss total 191.062 reconstruction 129.415 kl 61.647 3:24:54.822287
51/200 6015 loss: total 190.507 reconstruction 129.187 kl 61.320 val loss total 191.045 reconstruction 129.658 kl 61.387 3:29:20.952387
52/200 6015 loss: total 190.579 reconstruction 129.189 kl 61.390 val loss total 190.926 reconstruction 129.437 kl 61.489 3:33:49.332335
53/200 6015 loss: total 190.384 reconstruction 129.077 kl 61.307 val loss total 191.009 reconstruction 130.127 kl 60.882 3:38:17.467712
54/200 6015 loss: total 190.604 reconstruction 129.210 kl 61.394 val loss total 190.829 reconstruction 129.964 kl 60.864 3:42:44.749036
55/200 6015 loss: total 190.344 reconstruction 128.994 kl 61.350 val loss total 190.877 reconstruction 129.566 kl 61.311 3:47:11.161280
56/200 6015 loss: total 190.405 reconstruction 129.000 kl 61.405 val loss total 190.767 reconstruction 129.130 kl 61.637 3:51:35.811515
57/200 6015 loss: total 190.465 reconstruction 129.067 kl 61.398 val loss total 190.881 reconstruction 129.565 kl 61.316 3:55:59.106461
58/200 6015 loss: total 190.271 reconstruction 128.893 kl 61.377 val loss total 190.640 reconstruction 129.392 kl 61.248 4:00:24.005876
59/200 6015 loss: total 190.287 reconstruction 128.895 kl 61.391 val loss total 190.646 reconstruction 129.648 kl 60.998 4:04:53.825522
60/200 6015 loss: total 190.259 reconstruction 128.923 kl 61.336 val loss total 190.699 reconstruction 129.426 kl 61.273 4:09:19.352547
61/200 6015 loss: total 190.391 reconstruction 128.953 kl 61.437 val loss total 190.661 reconstruction 129.518 kl 61.143 4:13:42.718273
62/200 6015 loss: total 190.085 reconstruction 128.754 kl 61.332 val loss total 190.731 reconstruction 129.224 kl 61.508 4:18:06.786141
63/200 6015 loss: total 190.284 reconstruction 128.896 kl 61.388 val loss total 190.652 reconstruction 129.701 kl 60.951 4:22:31.262709
64/200 6015 loss: total 190.181 reconstruction 128.807 kl 61.374 val loss total 190.714 reconstruction 129.725 kl 60.989 4:26:55.542714
65/200 6015 loss: total 190.247 reconstruction 128.828 kl 61.419 val loss total 190.565 reconstruction 129.282 kl 61.282 4:31:19.501985
66/200 6015 loss: total 189.989 reconstruction 128.650 kl 61.339 val loss total 190.573 reconstruction 129.545 kl 61.028 4:35:43.661270
67/200 6015 loss: total 190.212 reconstruction 128.829 kl 61.383 val loss total 190.437 reconstruction 129.229 kl 61.208 4:40:07.482412
68/200 6015 loss: total 190.242 reconstruction 128.828 kl 61.414 val loss total 190.537 reconstruction 129.258 kl 61.280 4:44:31.648131
69/200 6015 loss: total 189.968 reconstruction 128.595 kl 61.373 val loss total 190.497 reconstruction 129.276 kl 61.221 4:48:58.191145
70/200 6015 loss: total 189.893 reconstruction 128.568 kl 61.325 val loss total 190.617 reconstruction 129.681 kl 60.936 4:53:23.222911
71/200 6015 loss: total 190.060 reconstruction 128.687 kl 61.373 val loss total 190.534 reconstruction 129.128 kl 61.406 4:57:49.206424
72/200 6015 loss: total 190.151 reconstruction 128.705 kl 61.446 val loss total 190.653 reconstruction 129.074 kl 61.579 5:02:15.506579
73/200 6015 loss: total 189.942 reconstruction 128.573 kl 61.370 val loss total 190.606 reconstruction 129.420 kl 61.187 5:06:41.421988
74/200 6015 loss: total 190.042 reconstruction 128.656 kl 61.386 val loss total 190.397 reconstruction 128.845 kl 61.553 5:11:06.739312
75/200 6015 loss: total 190.055 reconstruction 128.641 kl 61.414 val loss total 190.490 reconstruction 129.220 kl 61.270 5:15:32.529343
76/200 6015 loss: total 189.899 reconstruction 128.531 kl 61.368 val loss total 190.348 reconstruction 129.183 kl 61.165 5:19:57.000242
77/200 6015 loss: total 189.988 reconstruction 128.583 kl 61.406 val loss total 190.338 reconstruction 129.166 kl 61.172 5:24:22.411602
78/200 6015 loss: total 189.957 reconstruction 128.595 kl 61.362 val loss total 190.424 reconstruction 128.773 kl 61.650 5:28:47.330339
79/200 6015 loss: total 189.978 reconstruction 128.599 kl 61.379 val loss total 190.363 reconstruction 129.127 kl 61.236 5:33:12.213369
80/200 6015 loss: total 189.971 reconstruction 128.571 kl 61.401 val loss total 190.510 reconstruction 128.927 kl 61.583 5:37:38.632425
81/200 6015 loss: total 189.805 reconstruction 128.409 kl 61.396 val loss total 190.507 reconstruction 129.535 kl 60.972 5:42:07.472294
82/200 6015 loss: total 189.916 reconstruction 128.519 kl 61.397 val loss total 190.335 reconstruction 128.942 kl 61.393 5:46:32.956458
83/200 6015 loss: total 189.889 reconstruction 128.506 kl 61.384 val loss total 190.397 reconstruction 128.945 kl 61.452 5:50:58.041977
84/200 6015 loss: total 189.779 reconstruction 128.420 kl 61.359 val loss total 190.545 reconstruction 129.182 kl 61.362 5:55:21.654447
85/200 6015 loss: total 189.893 reconstruction 128.491 kl 61.402 val loss total 190.270 reconstruction 129.081 kl 61.189 5:59:44.539797
86/200 6015 loss: total 189.915 reconstruction 128.508 kl 61.407 val loss total 190.383 reconstruction 129.041 kl 61.341 6:04:08.094860
87/200 6015 loss: total 189.923 reconstruction 128.493 kl 61.430 val loss total 190.296 reconstruction 128.828 kl 61.468 6:08:31.680567
88/200 6015 loss: total 189.728 reconstruction 128.412 kl 61.316 val loss total 190.390 reconstruction 128.930 kl 61.461 6:12:55.308221
89/200 6015 loss: total 189.753 reconstruction 128.371 kl 61.382 val loss total 190.242 reconstruction 129.057 kl 61.185 6:17:17.191232
90/200 6015 loss: total 189.906 reconstruction 128.517 kl 61.389 val loss total 190.314 reconstruction 129.019 kl 61.295 6:21:39.701207
91/200 6015 loss: total 189.687 reconstruction 128.319 kl 61.368 val loss total 190.315 reconstruction 128.992 kl 61.323 6:26:03.115166
92/200 6015 loss: total 189.936 reconstruction 128.534 kl 61.402 val loss total 190.474 reconstruction 129.380 kl 61.094 6:30:26.444211
93/200 6015 loss: total 189.839 reconstruction 128.451 kl 61.388 val loss total 190.211 reconstruction 128.876 kl 61.335 6:34:48.766150
94/200 6015 loss: total 189.798 reconstruction 128.380 kl 61.418 val loss total 190.267 reconstruction 128.980 kl 61.288 6:39:10.425578
95/200 6015 loss: total 189.772 reconstruction 128.400 kl 61.372 val loss total 190.244 reconstruction 128.668 kl 61.576 6:43:33.181423
96/200 6015 loss: total 189.741 reconstruction 128.340 kl 61.401 val loss total 190.293 reconstruction 128.969 kl 61.323 6:47:56.068890
97/200 6015 loss: total 189.911 reconstruction 128.503 kl 61.407 val loss total 190.332 reconstruction 129.013 kl 61.319 6:52:20.228992
98/200 6015 loss: total 189.648 reconstruction 128.237 kl 61.411 val loss total 190.310 reconstruction 129.169 kl 61.142 6:56:42.519476
99/200 6015 loss: total 189.812 reconstruction 128.420 kl 61.392 val loss total 190.221 reconstruction 128.764 kl 61.457 7:01:05.204807
100/200 6015 loss: total 189.769 reconstruction 128.359 kl 61.409 val loss total 190.295 reconstruction 129.033 kl 61.262 7:05:29.354277
101/200 6015 loss: total 189.844 reconstruction 128.414 kl 61.430 val loss total 190.202 reconstruction 129.122 kl 61.081 7:09:52.537646
102/200 6015 loss: total 189.569 reconstruction 128.217 kl 61.352 val loss total 190.395 reconstruction 128.959 kl 61.436 7:14:15.736100
103/200 6015 loss: total 189.820 reconstruction 128.416 kl 61.404 val loss total 190.201 reconstruction 129.009 kl 61.192 7:18:38.017203
104/200 6015 loss: total 189.712 reconstruction 128.319 kl 61.393 val loss total 190.145 reconstruction 128.963 kl 61.182 7:22:59.870123
105/200 6015 loss: total 189.716 reconstruction 128.334 kl 61.382 val loss total 190.130 reconstruction 128.908 kl 61.222 7:27:20.508035
106/200 6015 loss: total 189.613 reconstruction 128.193 kl 61.420 val loss total 190.171 reconstruction 129.157 kl 61.014 7:31:40.944213
107/200 6015 loss: total 189.906 reconstruction 128.501 kl 61.405 val loss total 190.118 reconstruction 128.818 kl 61.300 7:36:01.134981
108/200 6015 loss: total 189.753 reconstruction 128.327 kl 61.426 val loss total 190.263 reconstruction 128.958 kl 61.305 7:40:22.268277
109/200 6015 loss: total 189.611 reconstruction 128.245 kl 61.366 val loss total 190.186 reconstruction 128.910 kl 61.276 7:44:45.580816
110/200 6015 loss: total 189.852 reconstruction 128.412 kl 61.440 val loss total 190.289 reconstruction 128.885 kl 61.405 7:49:06.185749
111/200 6015 loss: total 189.757 reconstruction 128.358 kl 61.399 val loss total 190.200 reconstruction 128.827 kl 61.373 7:53:27.280182
112/200 6015 loss: total 189.616 reconstruction 128.255 kl 61.361 val loss total 190.327 reconstruction 129.196 kl 61.130 7:57:46.656329
113/200 6015 loss: total 189.735 reconstruction 128.362 kl 61.373 val loss total 190.198 reconstruction 129.048 kl 61.150 8:02:06.877047
114/200 6015 loss: total 189.620 reconstruction 128.217 kl 61.403 val loss total 190.324 reconstruction 129.152 kl 61.172 8:06:27.440477
115/200 6015 loss: total 189.780 reconstruction 128.398 kl 61.382 val loss total 190.091 reconstruction 128.867 kl 61.224 8:10:47.626755
116/200 6015 loss: total 189.620 reconstruction 128.227 kl 61.394 val loss total 190.210 reconstruction 128.808 kl 61.402 8:15:07.837734
117/200 6015 loss: total 189.679 reconstruction 128.294 kl 61.385 val loss total 190.079 reconstruction 128.799 kl 61.280 8:19:27.215123
118/200 6015 loss: total 189.642 reconstruction 128.260 kl 61.382 val loss total 190.313 reconstruction 128.958 kl 61.356 8:23:47.512881
119/200 6015 loss: total 189.667 reconstruction 128.274 kl 61.393 val loss total 190.110 reconstruction 128.708 kl 61.401 8:28:07.862566
120/200 6015 loss: total 189.722 reconstruction 128.315 kl 61.406 val loss total 190.133 reconstruction 129.061 kl 61.072 8:32:32.216067
121/200 6015 loss: total 189.717 reconstruction 128.324 kl 61.393 val loss total 190.154 reconstruction 128.917 kl 61.237 8:36:54.738705
122/200 6015 loss: total 189.657 reconstruction 128.260 kl 61.397 val loss total 190.205 reconstruction 128.821 kl 61.385 8:41:16.586956
123/200 6015 loss: total 189.654 reconstruction 128.288 kl 61.366 val loss total 190.243 reconstruction 128.887 kl 61.356 8:45:37.854442
124/200 6015 loss: total 189.752 reconstruction 128.324 kl 61.428 val loss total 190.109 reconstruction 128.916 kl 61.193 8:49:59.454717
125/200 6015 loss: total 189.613 reconstruction 128.243 kl 61.370 val loss total 190.166 reconstruction 128.825 kl 61.340 8:54:22.846552
126/200 6015 loss: total 189.624 reconstruction 128.245 kl 61.379 val loss total 190.251 reconstruction 129.030 kl 61.222 8:58:45.302308
127/200 6015 loss: total 189.768 reconstruction 128.370 kl 61.398 val loss total 190.025 reconstruction 128.583 kl 61.442 9:03:08.031802
128/200 6015 loss: total 189.625 reconstruction 128.263 kl 61.363 val loss total 190.229 reconstruction 128.919 kl 61.310 9:07:30.957156
129/200 6015 loss: total 189.613 reconstruction 128.224 kl 61.390 val loss total 190.118 reconstruction 128.797 kl 61.321 9:11:56.504067
130/200 6015 loss: total 189.675 reconstruction 128.311 kl 61.364 val loss total 190.378 reconstruction 128.996 kl 61.382 9:16:23.939981
131/200 6015 loss: total 189.654 reconstruction 128.273 kl 61.381 val loss total 189.999 reconstruction 128.685 kl 61.314 9:20:49.017514
132/200 6015 loss: total 189.576 reconstruction 128.225 kl 61.351 val loss total 190.130 reconstruction 128.922 kl 61.208 9:25:13.666311
133/200 6015 loss: total 189.620 reconstruction 128.234 kl 61.386 val loss total 190.300 reconstruction 128.988 kl 61.311 9:29:39.861686
134/200 6015 loss: total 189.673 reconstruction 128.252 kl 61.422 val loss total 190.170 reconstruction 128.922 kl 61.249 9:34:04.696124
135/200 6015 loss: total 189.756 reconstruction 128.355 kl 61.401 val loss total 190.181 reconstruction 129.038 kl 61.143 9:38:30.420747
136/200 6015 loss: total 189.549 reconstruction 128.184 kl 61.365 val loss total 190.169 reconstruction 128.710 kl 61.458 9:42:56.187432
137/200 6015 loss: total 189.646 reconstruction 128.251 kl 61.395 val loss total 190.084 reconstruction 128.901 kl 61.183 9:47:21.540988
138/200 6015 loss: total 189.644 reconstruction 128.266 kl 61.378 val loss total 190.178 reconstruction 128.844 kl 61.335 9:51:47.315635
139/200 6015 loss: total 189.732 reconstruction 128.359 kl 61.373 val loss total 190.086 reconstruction 128.844 kl 61.242 9:56:13.596368
140/200 6015 loss: total 189.496 reconstruction 128.096 kl 61.400 val loss total 190.208 reconstruction 128.915 kl 61.293 10:00:41.678795
141/200 6015 loss: total 189.738 reconstruction 128.316 kl 61.423 val loss total 190.011 reconstruction 128.747 kl 61.264 10:05:09.859723
142/200 6015 loss: total 189.686 reconstruction 128.323 kl 61.363 val loss total 190.181 reconstruction 128.947 kl 61.234 10:09:36.739844
143/200 6015 loss: total 189.505 reconstruction 128.126 kl 61.379 val loss total 190.168 reconstruction 128.869 kl 61.299 10:14:03.030662
144/200 6015 loss: total 189.709 reconstruction 128.312 kl 61.398 val loss total 190.200 reconstruction 129.009 kl 61.191 10:18:28.746619
145/200 6015 loss: total 189.527 reconstruction 128.183 kl 61.344 val loss total 190.090 reconstruction 128.834 kl 61.256 10:22:55.068533
146/200 6015 loss: total 189.729 reconstruction 128.329 kl 61.400 val loss total 190.309 reconstruction 128.995 kl 61.314 10:27:21.211608
147/200 6015 loss: total 189.608 reconstruction 128.226 kl 61.382 val loss total 190.137 reconstruction 128.905 kl 61.232 10:31:46.270507
148/200 6015 loss: total 189.633 reconstruction 128.240 kl 61.394 val loss total 190.203 reconstruction 128.914 kl 61.289 10:36:11.594179
149/200 6015 loss: total 189.677 reconstruction 128.249 kl 61.428 val loss total 190.265 reconstruction 129.061 kl 61.204 10:40:37.343848
150/200 6015 loss: total 189.783 reconstruction 128.391 kl 61.392 val loss total 190.105 reconstruction 128.813 kl 61.291 10:45:04.091802
151/200 6015 loss: total 189.479 reconstruction 128.100 kl 61.380 val loss total 190.205 reconstruction 128.910 kl 61.295 10:49:29.456093
152/200 6015 loss: total 189.696 reconstruction 128.318 kl 61.378 val loss total 190.137 reconstruction 128.851 kl 61.286 10:53:55.848796
153/200 6015 loss: total 189.544 reconstruction 128.173 kl 61.371 val loss total 190.181 reconstruction 128.876 kl 61.305 10:58:22.135650
154/200 6015 loss: total 189.651 reconstruction 128.270 kl 61.381 val loss total 190.102 reconstruction 128.826 kl 61.277 11:02:49.476577
155/200 6015 loss: total 189.680 reconstruction 128.284 kl 61.396 val loss total 190.244 reconstruction 128.941 kl 61.303 11:07:15.140242
156/200 6015 loss: total 189.613 reconstruction 128.214 kl 61.399 val loss total 190.251 reconstruction 128.912 kl 61.338 11:11:41.008423
157/200 6015 loss: total 189.596 reconstruction 128.225 kl 61.371 val loss total 190.091 reconstruction 128.781 kl 61.310 11:16:07.700525
158/200 6015 loss: total 189.639 reconstruction 128.233 kl 61.406 val loss total 190.231 reconstruction 128.991 kl 61.240 11:20:34.165837
159/200 6015 loss: total 189.727 reconstruction 128.329 kl 61.399 val loss total 190.109 reconstruction 128.867 kl 61.242 11:25:00.990971
160/200 6015 loss: total 189.471 reconstruction 128.086 kl 61.385 val loss total 190.093 reconstruction 128.818 kl 61.275 11:29:26.262069
161/200 6015 loss: total 189.589 reconstruction 128.203 kl 61.386 val loss total 190.223 reconstruction 128.925 kl 61.298 11:33:52.204276
162/200 6015 loss: total 189.805 reconstruction 128.393 kl 61.411 val loss total 190.219 reconstruction 128.899 kl 61.320 11:38:18.429700
163/200 6015 loss: total 189.500 reconstruction 128.122 kl 61.379 val loss total 190.115 reconstruction 128.865 kl 61.250 11:42:44.246056
164/200 6015 loss: total 189.811 reconstruction 128.397 kl 61.414 val loss total 190.195 reconstruction 128.948 kl 61.247 11:47:09.796045
165/200 6015 loss: total 189.579 reconstruction 128.199 kl 61.380 val loss total 190.067 reconstruction 128.765 kl 61.302 11:51:35.220580
166/200 6015 loss: total 189.522 reconstruction 128.122 kl 61.400 val loss total 190.110 reconstruction 128.816 kl 61.294 11:56:00.472572
167/200 6015 loss: total 189.640 reconstruction 128.246 kl 61.395 val loss total 190.187 reconstruction 128.856 kl 61.331 12:00:26.833856
168/200 6015 loss: total 189.681 reconstruction 128.276 kl 61.405 val loss total 189.991 reconstruction 128.711 kl 61.280 12:04:53.200078
169/200 6015 loss: total 189.543 reconstruction 128.172 kl 61.371 val loss total 190.168 reconstruction 128.941 kl 61.227 12:09:19.252857
170/200 6015 loss: total 189.637 reconstruction 128.245 kl 61.393 val loss total 190.111 reconstruction 128.833 kl 61.279 12:13:45.338741
171/200 6015 loss: total 189.695 reconstruction 128.274 kl 61.421 val loss total 190.020 reconstruction 128.760 kl 61.259 12:18:11.096868
172/200 6015 loss: total 189.734 reconstruction 128.340 kl 61.395 val loss total 190.145 reconstruction 128.897 kl 61.248 12:22:36.571231
173/200 6015 loss: total 189.451 reconstruction 128.114 kl 61.336 val loss total 190.161 reconstruction 128.903 kl 61.258 12:27:01.984187
174/200 6015 loss: total 189.819 reconstruction 128.376 kl 61.442 val loss total 190.065 reconstruction 128.763 kl 61.302 12:31:27.095194
175/200 6015 loss: total 189.454 reconstruction 128.088 kl 61.366 val loss total 190.097 reconstruction 128.828 kl 61.269 12:35:52.706185
176/200 6015 loss: total 189.638 reconstruction 128.234 kl 61.403 val loss total 190.098 reconstruction 128.817 kl 61.281 12:40:19.123477
177/200 6015 loss: total 189.670 reconstruction 128.251 kl 61.419 val loss total 190.039 reconstruction 128.834 kl 61.204 12:44:45.313097
178/200 6015 loss: total 189.608 reconstruction 128.245 kl 61.364 val loss total 190.212 reconstruction 128.949 kl 61.263 12:49:11.985482
179/200 6015 loss: total 189.719 reconstruction 128.299 kl 61.420 val loss total 190.074 reconstruction 128.851 kl 61.223 12:53:38.367333
180/200 6015 loss: total 189.574 reconstruction 128.197 kl 61.377 val loss total 190.127 reconstruction 128.842 kl 61.285 12:58:04.289879
181/200 6015 loss: total 189.650 reconstruction 128.251 kl 61.399 val loss total 190.135 reconstruction 128.807 kl 61.328 13:02:29.627651
182/200 6015 loss: total 189.617 reconstruction 128.227 kl 61.390 val loss total 190.084 reconstruction 128.766 kl 61.318 13:06:54.999391
183/200 6015 loss: total 189.515 reconstruction 128.142 kl 61.373 val loss total 190.108 reconstruction 128.823 kl 61.285 13:11:20.206714
184/200 6015 loss: total 189.682 reconstruction 128.264 kl 61.417 val loss total 190.180 reconstruction 128.883 kl 61.297 13:15:45.661520
185/200 6015 loss: total 189.644 reconstruction 128.242 kl 61.402 val loss total 190.139 reconstruction 128.830 kl 61.309 13:20:11.882997
186/200 6015 loss: total 189.697 reconstruction 128.282 kl 61.415 val loss total 190.172 reconstruction 128.896 kl 61.276 13:24:39.363664
187/200 6015 loss: total 189.573 reconstruction 128.177 kl 61.396 val loss total 189.989 reconstruction 128.672 kl 61.317 13:29:05.981900
188/200 6015 loss: total 189.670 reconstruction 128.283 kl 61.387 val loss total 190.174 reconstruction 128.943 kl 61.231 13:33:31.592366
189/200 6015 loss: total 189.592 reconstruction 128.211 kl 61.381 val loss total 190.123 reconstruction 128.832 kl 61.291 13:37:57.034574
190/200 6015 loss: total 189.610 reconstruction 128.194 kl 61.416 val loss total 190.238 reconstruction 128.945 kl 61.293 13:42:21.798357
191/200 6015 loss: total 189.566 reconstruction 128.175 kl 61.391 val loss total 190.202 reconstruction 128.905 kl 61.297 13:46:46.394935
192/200 6015 loss: total 189.711 reconstruction 128.323 kl 61.388 val loss total 190.123 reconstruction 128.840 kl 61.283 13:51:11.165649
193/200 6015 loss: total 189.568 reconstruction 128.184 kl 61.384 val loss total 190.100 reconstruction 128.798 kl 61.302 13:55:35.786728
194/200 6015 loss: total 189.641 reconstruction 128.222 kl 61.419 val loss total 190.077 reconstruction 128.811 kl 61.266 14:00:00.058804
195/200 6015 loss: total 189.748 reconstruction 128.353 kl 61.394 val loss total 190.079 reconstruction 128.791 kl 61.288 14:04:24.774816
196/200 6015 loss: total 189.413 reconstruction 128.065 kl 61.349 val loss total 190.087 reconstruction 128.803 kl 61.285 14:08:49.015329
197/200 6015 loss: total 189.734 reconstruction 128.309 kl 61.425 val loss total 190.088 reconstruction 128.804 kl 61.284 14:13:15.174220
198/200 6015 loss: total 189.586 reconstruction 128.189 kl 61.397 val loss total 190.088 reconstruction 128.803 kl 61.286 14:17:40.344938
199/200 6015 loss: total 189.759 reconstruction 128.366 kl 61.393 val loss total 190.105 reconstruction 128.818 kl 61.287 14:22:04.982790
200/200 6015 loss: total 189.598 reconstruction 128.187 kl 61.410 val loss total 190.061 reconstruction 128.790 kl 61.271 14:26:30.451611
In [53]:
loss3_2 = log3_2['loss']
rloss3_2 = log3_2['reconstruction_loss']
kloss3_2 = log3_2['kl_loss']
val_loss3_2 = log3_2['val_loss']
val_rloss3_2 = log3_2['val_reconstruction_loss']
val_kloss3_2 = log3_2['val_kl_loss']
In [54]:
loss3 = np.concatenate([loss3_1, loss3_2], axis=0)
rloss3 = np.concatenate([rloss3_1, rloss3_2], axis=0)
kloss3 = np.concatenate([kloss3_1, kloss3_2], axis=0)

val_loss3 = np.concatenate([val_loss3_1, val_loss3_2], axis=0)
val_rloss3 = np.concatenate([val_rloss3_1, val_rloss3_2], axis=0)
val_kloss3 = np.concatenate([val_kloss3_1, val_kloss3_2], axis=0)
In [55]:
VariationalAutoEncoder.plot_history(
    [loss3, val_loss3], 
    ['total_loss', 'val_total_loss']
)
In [56]:
VariationalAutoEncoder.plot_history(
    [rloss3, val_rloss3], 
    ['reconstruction_loss', 'val_reconstruction_loss']
)
In [57]:
VariationalAutoEncoder.plot_history(
    [kloss3, val_kloss3], 
    ['kl_loss', 'val_kl_loss']
)
In [58]:
x_, _ = next(val_data_flow)
selected_images = x_[:10]
In [60]:
z_mean3, z_log_var3, z3 = vae3.encoder(selected_images)
reconst_images3 = vae3.decoder(z3).numpy()  # decoder() returns a Tensor, so convert it to a numpy array.
txts3 = [f'{p[0]:.3f}, {p[1]:.3f}' for p in z3 ]
In [62]:
%matplotlib inline

VariationalAutoEncoder.showImages(selected_images, reconst_images3, txts3, 1.4, 1.4)

Save the loss transition for future training.

Save the loss transition of the '(3) Training' method to the file 'loss_N.pkl'.


In [63]:
# Save loss variables for future training
import os
import pickle

var_path = f'{save_path3}/loss_{vae3.epoch-1}.pkl'

dpath, fname = os.path.split(var_path)
if dpath != '' and not os.path.exists(dpath):
    os.makedirs(dpath)

with open(var_path, 'wb') as f:
    pickle.dump([
        loss3, 
        rloss3, 
        kloss3, 
        val_loss3, 
        val_rloss3, 
        val_kloss3         
    ], f)
In [ ]: