Updated 19/Nov/2021 by Yoshihisa Nitta  

AutoEncoder Training for the MNIST Dataset with TensorFlow 2 on Google Colab

In [ ]:
#! pip install tensorflow==2.7.0
In [ ]:
%tensorflow_version 2.x

import tensorflow as tf
print(tf.__version__)
2.7.0

AutoEncoder

  • Diederik P. Kingma, Max Welling: Auto-Encoding Variational Bayes, 2013
    https://arxiv.org/abs/1312.6114
  • The paper addresses the question of how to perform efficient inference and learning in directed probabilistic models whose continuous latent variables have intractable posterior distributions, especially on large datasets. It introduces a stochastic variational inference and learning algorithm that applies under mild differentiability conditions. The contributions are twofold: a reparameterization of the variational lower bound yields a lower-bound estimator that can be trained with ordinary SGD, and fitting an approximate inference model with this estimator makes posterior inference efficient for i.i.d. datasets with continuous latent variables per data point. (A minimal sketch of the reparameterization trick follows below.)
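The following is a rough, minimal sketch of that reparameterization trick, for illustration only; the plain AutoEncoder trained in this notebook does not use it, and the shapes below are arbitrary examples.

import tensorflow as tf

# Reparameterization trick: instead of sampling z ~ N(mu, sigma^2) directly,
# draw eps ~ N(0, I) and compute z = mu + sigma * eps, so gradients can flow
# through mu and log_var and the lower bound can be optimized with plain SGD.
mu = tf.zeros((8, 2))        # example batch of latent means
log_var = tf.zeros((8, 2))   # example batch of latent log-variances
eps = tf.random.normal(shape=tf.shape(mu))
z = mu + tf.exp(0.5 * log_var) * eps
print(z.shape)               # (8, 2)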

Check the execution environment on Google Colab

In [ ]:
! nvidia-smi
! cat /proc/cpuinfo
! cat /etc/issue
! free -h
Mon Nov 22 06:21:23 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2199.998
cache size	: 56320 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4399.99
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

processor	: 1
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2199.998
cache size	: 56320 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 1
initial apicid	: 1
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4399.99
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

Ubuntu 18.04.5 LTS \n \l

              total        used        free      shared  buff/cache   available
Mem:            12G        748M        9.9G        1.2M        2.0G         11G
Swap:            0B          0B          0B

Mount Google Drive from Google Colab

In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
In [ ]:
! ls /content/drive
MyDrive  Shareddrives

Download the source file from Google Drive or nw.tsuda.ac.jp

Basically, use gdown to fetch the file from Google Drive. Download it from nw.tsuda.ac.jp only if the Google Drive specifications change and the file can no longer be downloaded from Google Drive.

In [ ]:
# Download source file
nw_path = './nw'
! rm -rf {nw_path}
! mkdir -p {nw_path}

if True: # from Google Drive
    url_model = 'https://drive.google.com/uc?id=1ZDgWE7wmVwG_ZuQVUjuh_XHeIO-7Yn63'
    ! (cd {nw_path}; gdown {url_model})
else:    # from nw.tsuda.ac.jp
    URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'
    url_model = f'{URL_NW}/models/AutoEncoder.py'
    ! wget -nd {url_model} -P {nw_path}     # download to './nw/AutoEncoder.py'
Downloading...
From: https://drive.google.com/uc?id=1ZDgWE7wmVwG_ZuQVUjuh_XHeIO-7Yn63
To: /content/nw/AutoEncoder.py
100% 13.9k/13.9k [00:00<00:00, 21.5MB/s]
In [ ]:
!cat {nw_path}/AutoEncoder.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import os
import pickle
import datetime

class AutoEncoder():
    def __init__(self, 
                 input_dim,
                 encoder_conv_filters,
                 encoder_conv_kernel_size,
                 encoder_conv_strides,
                 decoder_conv_t_filters,
                 decoder_conv_t_kernel_size,
                 decoder_conv_t_strides,
                 z_dim,
                 use_batch_norm = False,
                 use_dropout = False,
                 epoch = 0
    ):
        self.name = 'autoencoder'
        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim
        
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout

        self.epoch = epoch
        
        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)
        
        self._build()
 

    def _build(self):
        ### THE ENCODER
        encoder_input = tf.keras.layers.Input(shape=self.input_dim, name='encoder_input')
        x = encoder_input
    
        for i in range(self.n_layers_encoder):
            x = tf.keras.layers.Conv2D(
                filters = self.encoder_conv_filters[i],
                kernel_size = self.encoder_conv_kernel_size[i],
                strides = self.encoder_conv_strides[i],
                padding  = 'same',
                name = 'encoder_conv_' + str(i)
            )(x)
            x = tf.keras.layers.LeakyReLU()(x)
            if self.use_batch_norm:
                x = tf.keras.layers.BatchNormalization()(x)
            if self.use_dropout:
                x = tf.keras.layers.Dropout(rate = 0.25)(x)
              
        shape_before_flattening = tf.keras.backend.int_shape(x)[1:] # shape for 1 data
        
        x = tf.keras.layers.Flatten()(x)
        encoder_output = tf.keras.layers.Dense(self.z_dim, name='encoder_output')(x)
        
        self.encoder = tf.keras.models.Model(encoder_input, encoder_output)
        
        ### THE DECODER
        decoder_input = tf.keras.layers.Input(shape=(self.z_dim,), name='decoder_input')
        x = tf.keras.layers.Dense(np.prod(shape_before_flattening))(decoder_input)
        x = tf.keras.layers.Reshape(shape_before_flattening)(x)
        
        for i in range(self.n_layers_decoder):
            x =   tf.keras.layers.Conv2DTranspose(
                filters = self.decoder_conv_t_filters[i],
                kernel_size = self.decoder_conv_t_kernel_size[i],
                strides = self.decoder_conv_t_strides[i],
                padding = 'same',
                name = 'decoder_conv_t_' + str(i)
            )(x)
          
            if i < self.n_layers_decoder - 1:
                x = tf.keras.layers.LeakyReLU()(x)
                if self.use_batch_norm:
                    x = tf.keras.layers.BatchNormalization()(x)
                if self.use_dropout:
                    x = tf.keras.layers.Dropout(rate=0.25)(x)
            else:
                x = tf.keras.layers.Activation('sigmoid')(x)
              
        decoder_output = x
        self.decoder = tf.keras.models.Model(decoder_input, decoder_output)
        
        ### THE FULL AUTOENCODER
        model_input = encoder_input
        model_output = self.decoder(encoder_output)
        
        self.model = tf.keras.models.Model(model_input, model_output)


    def save(self, folder):
        self.save_params(os.path.join(folder, 'params.pkl'))
        self.save_weights(os.path.join(folder, 'weights/weights.h5'))


    @staticmethod
    def load(folder, epoch=None):   # AutoEncoder.load(folder)
        params = AutoEncoder.load_params(os.path.join(folder, 'params.pkl'))
        AE = AutoEncoder(*params)
        if epoch is None:
            AE.model.load_weights(os.path.join(folder, 'weights/weights.h5'))
        else:
            AE.model.load_weights(os.path.join(folder, f'weights/weights_{epoch-1}.h5'))
            AE.epoch = epoch

        return AE


    def save_params(self, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        with open(filepath, 'wb') as f:
            pickle.dump([
                self.input_dim,
                self.encoder_conv_filters,
                self.encoder_conv_kernel_size,
                self.encoder_conv_strides,
                self.decoder_conv_t_filters,
                self.decoder_conv_t_kernel_size,
                self.decoder_conv_t_strides,
                self.z_dim,
                self.use_batch_norm,
                self.use_dropout,
                self.epoch
            ], f)


    @staticmethod
    def load_params(filepath):
        with open(filepath, 'rb') as f:
            params = pickle.load(f)
        return params


    def save_weights(self, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        self.model.save_weights(filepath)
        
        
    def load_weights(self, filepath):
        self.model.load_weights(filepath)


    def save_images(self, imgs, filepath):
        z_points = self.encoder.predict(imgs)
        reconst_imgs = self.decoder.predict(z_points)
        txts = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z_points ]
        AutoEncoder.showImages(imgs, reconst_imgs, txts, 1.4, 1.4, 0.5, filepath)
      

    @staticmethod
    def r_loss(y_true, y_pred):
        return tf.keras.backend.mean(tf.keras.backend.square(y_true - y_pred), axis=[1,2,3])


    def compile(self, learning_rate):
        self.learning_rate = learning_rate
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.model.compile(optimizer=optimizer, loss = AutoEncoder.r_loss)

        
    def train_with_fit(self,
               x_train,
               y_train,
               batch_size,
               epochs,
               run_folder='run/',
               validation_data=None
    ):
        history= self.model.fit(
            x_train,
            y_train,
            batch_size = batch_size,
            shuffle = True,
            initial_epoch = self.epoch,
            epochs = epochs,
            validation_data = validation_data
        )
        if self.epoch < epochs:
            self.epoch = epochs

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch-1}.h5'))
            #idxs = np.random.choice(len(x_train), 10)
            #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch-1}.png'))

        return history
        
        
    def train(self,
               x_train,
               y_train,
               batch_size = 32,
               epochs = 10,
               shuffle=False,
               run_folder='run/',
               optimizer=None,
               save_epoch_interval=100,
               validation_data = None
    ):
        start_time = datetime.datetime.now()
        steps = x_train.shape[0] // batch_size

        losses = []
        val_losses = []

        for epoch in range(self.epoch, epochs):
            epoch_loss = 0
            indices = tf.range(x_train.shape[0], dtype=tf.int32)
            if shuffle:
                indices = tf.random.shuffle(indices)
            x_ = x_train[indices]
            y_ = y_train[indices]
            
            for step in range(steps):
                start = batch_size * step
                end = start + batch_size

                with tf.GradientTape() as tape:
                    outputs = self.model(x_[start:end])
                    tmp_loss = AutoEncoder.r_loss(y_[start:end], outputs)

                grads = tape.gradient(tmp_loss, self.model.trainable_variables)
                optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

            epoch_loss = np.mean(tmp_loss)
            losses.append(epoch_loss)

            val_str = ''
            if validation_data != None:
                x_val, y_val = validation_data
                outputs_val = self.model(x_val)
                val_loss = np.mean(AutoEncoder.r_loss(y_val, outputs_val))
                val_str = f'val loss: {val_loss:.4f}  '
                val_losses.append(val_loss)


            if (epoch+1) % save_epoch_interval == 0 and run_folder != None:
                self.save(run_folder)
                self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch}.h5'))
                #idxs = np.random.choice(len(x_train), 10)
                #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch}.png'))

            elapsed_time = datetime.datetime.now() - start_time
            print(f'{epoch+1}/{epochs} {steps} loss: {epoch_loss:.4f}  {val_str}{elapsed_time}')

            self.epoch += 1

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch-1}.h5'))
            #idxs = np.random.choice(len(x_train), 10)
            #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch-1}.png'))

        return losses, val_losses

    @staticmethod
    @tf.function
    def compute_loss_and_grads(model,x,y):
        with tf.GradientTape() as tape:
            outputs = model(x)
            tmp_loss = AutoEncoder.r_loss(y,outputs)
        grads = tape.gradient(tmp_loss, model.trainable_variables)
        return tmp_loss, grads


    def train_tf(self,
               x_train,
               y_train,
               batch_size = 32,
               epochs = 10,
               shuffle=False,
               run_folder='run/',
               optimizer=None,
               save_epoch_interval=100,
               validation_data = None
    ):
        start_time = datetime.datetime.now()
        steps = x_train.shape[0] // batch_size

        losses = []
        val_losses = []

        for epoch in range(self.epoch, epochs):
            epoch_loss = 0
            indices = tf.range(x_train.shape[0], dtype=tf.int32)
            if shuffle:
                indices = tf.random.shuffle(indices)
            x_ = x_train[indices]
            y_ = y_train[indices]

            step_losses = []
            for step in range(steps):
                start = batch_size * step
                end = start + batch_size

                tmp_loss, grads = AutoEncoder.compute_loss_and_grads(self.model, x_[start:end], y_[start:end])
                optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

                step_losses.append(np.mean(tmp_loss))

            epoch_loss = np.mean(step_losses)
            losses.append(epoch_loss)

            val_str = ''
            if validation_data != None:
                x_val, y_val = validation_data
                outputs_val = self.model(x_val)
                val_loss = np.mean(AutoEncoder.r_loss(y_val, outputs_val))
                val_str = f'val loss: {val_loss:.4f}  '
                val_losses.append(val_loss)


            if (epoch+1) % save_epoch_interval == 0 and run_folder != None:
                self.save(run_folder)
                self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch}.h5'))
                #idxs = np.random.choice(len(x_train), 10)
                #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch}.png'))

            elapsed_time = datetime.datetime.now() - start_time
            print(f'{epoch+1}/{epochs} {steps} loss: {epoch_loss:.4f}  {val_str}{elapsed_time}')

            self.epoch += 1

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(os.path.join(run_folder,f'weights/weights_{self.epoch-1}.h5'))
            #idxs = np.random.choice(len(x_train), 10)
            #self.save_images(x_train[idxs], os.path.join(run_folder, f'images/image_{self.epoch-1}.png'))

        return losses, val_losses


    @staticmethod
    def showImages(imgs1, imgs2, txts, w, h, vskip=0.5, filepath=None):
        n = len(imgs1)
        fig, ax = plt.subplots(2, n, figsize=(w * n, (2+vskip) * h))
        for i in range(n):
            if n == 1:
                axis = ax[0]
            else:
                axis = ax[0][i]
            img = imgs1[i].squeeze()
            axis.imshow(img, cmap='gray_r')
            axis.axis('off')

            axis.text(0.5, -0.35, txts[i], fontsize=10, ha='center', transform=axis.transAxes)

            if n == 1:
                axis = ax[1]
            else:
                axis = ax[1][i]
            img2 = imgs2[i].squeeze()
            axis.imshow(img2, cmap='gray_r')
            axis.axis('off')

        if not filepath is None:
            dpath, fname = os.path.split(filepath)
            if dpath != '' and not os.path.exists(dpath):
                os.makedirs(dpath)
            fig.savefig(filepath, dpi=600)
            plt.close()
        else:
            plt.show()

    @staticmethod
    def plot_history(vals, labels):
        colors = ['red', 'blue', 'green', 'orange', 'black']
        n = len(vals)
        fig, ax = plt.subplots(1, 1, figsize=(9,4))
        for i in range(n):
            ax.plot(vals[i], c=colors[i], label=labels[i])
        ax.legend(loc='upper right')
        ax.set_xlabel('epochs')
        # ax[0].set_ylabel('loss')
        
        plt.show()

Preparing the MNIST dataset

In [ ]:
import tensorflow as tf
import numpy as np
In [ ]:
# MNIST datasets
(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = tf.keras.datasets.mnist.load_data()
print(x_train_raw.shape)
print(y_train_raw.shape)
print(x_test_raw.shape)
print(y_test_raw.shape)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)
In [ ]:
x_train = x_train_raw.reshape(x_train_raw.shape+(1,)).astype('float32') / 255.0
x_test = x_test_raw.reshape(x_test_raw.shape+(1,)).astype('float32') / 255.0
print(x_train.shape)
print(x_test.shape)
(60000, 28, 28, 1)
(10000, 28, 28, 1)

Define the Neural Network Model

Use the AutoEncoder class downloaded above (from Google Drive or nw.tsuda.ac.jp).

In [ ]:
from nw.AutoEncoder import AutoEncoder

AE = AutoEncoder(
    input_dim = (28, 28, 1),
    encoder_conv_filters = [32, 64, 64, 64],
    encoder_conv_kernel_size = [3, 3, 3, 3],
    encoder_conv_strides = [1, 2, 2, 1],
    decoder_conv_t_filters = [64, 64, 32, 1],
    decoder_conv_t_kernel_size = [3, 3, 3, 3],
    decoder_conv_t_strides = [1, 2, 2, 1],
    z_dim = 2
)
In [ ]:
AE.encoder.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 encoder_input (InputLayer)  [(None, 28, 28, 1)]       0         
                                                                 
 encoder_conv_0 (Conv2D)     (None, 28, 28, 32)        320       
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 28, 28, 32)        0         
                                                                 
 encoder_conv_1 (Conv2D)     (None, 14, 14, 64)        18496     
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 14, 14, 64)        0         
                                                                 
 encoder_conv_2 (Conv2D)     (None, 7, 7, 64)          36928     
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 7, 7, 64)          0         
                                                                 
 encoder_conv_3 (Conv2D)     (None, 7, 7, 64)          36928     
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 7, 7, 64)          0         
                                                                 
 flatten (Flatten)           (None, 3136)              0         
                                                                 
 encoder_output (Dense)      (None, 2)                 6274      
                                                                 
=================================================================
Total params: 98,946
Trainable params: 98,946
Non-trainable params: 0
_________________________________________________________________
In [ ]:
AE.decoder.summary()
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 decoder_input (InputLayer)  [(None, 2)]               0         
                                                                 
 dense (Dense)               (None, 3136)              9408      
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 decoder_conv_t_0 (Conv2DTra  (None, 7, 7, 64)         36928     
 nspose)                                                         
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 7, 7, 64)          0         
                                                                 
 decoder_conv_t_1 (Conv2DTra  (None, 14, 14, 64)       36928     
 nspose)                                                         
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 14, 14, 64)        0         
                                                                 
 decoder_conv_t_2 (Conv2DTra  (None, 28, 28, 32)       18464     
 nspose)                                                         
                                                                 
 leaky_re_lu_6 (LeakyReLU)   (None, 28, 28, 32)        0         
                                                                 
 decoder_conv_t_3 (Conv2DTra  (None, 28, 28, 1)        289       
 nspose)                                                         
                                                                 
 activation (Activation)     (None, 28, 28, 1)         0         
                                                                 
=================================================================
Total params: 102,017
Trainable params: 102,017
Non-trainable params: 0
_________________________________________________________________

Training the Neural Model

Try training in three ways.

In each case, first train for a few epochs and save the state to files. Then load the saved state and continue training, as sketched below.

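The save-and-resume pattern shared by all three variants looks roughly like the sketch below (illustration only; 'run/AE_example' is a made-up folder name, while the actual cells below save to paths under /content/drive).

# Rough sketch of the save-and-resume pattern used in each experiment below.
# 'run/AE_example' is an illustrative folder name, not one used in this notebook.
from nw.AutoEncoder import AutoEncoder

# 1) Build an AutoEncoder and train it for a few epochs; train_with_fit(), train()
#    and train_tf() all save params.pkl and weights.h5 under their run_folder.
# 2) Rebuild the model from the saved files; the epoch counter is restored as well.
AE_resumed = AutoEncoder.load('run/AE_example')
print(AE_resumed.epoch)
# 3) Compile again and continue training up to the target number of epochs.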

In [ ]:
MAX_EPOCHS = 200
In [ ]:
learning_rate = 0.0005

(1) Simple Training with fit()

Instead of using callbacks, simply train with the fit() function.

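For contrast only, a callback-based version might look roughly like the sketch below; it is not used in this notebook, and the checkpoint path is illustrative.

# Not used in this notebook: a rough sketch of the callback-based alternative.
# A ModelCheckpoint callback saves weights during fit() instead of calling
# save()/save_weights() manually afterwards. The filepath is illustrative.
ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='/content/drive/MyDrive/ColabRun/AE01/weights/weights_{epoch}.h5',
    save_weights_only=True
)
# AE.model.fit(x_train, x_train, batch_size=32, epochs=3,
#              validation_data=(x_test, x_test), callbacks=[ckpt_cb])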

In [ ]:
save_path1 = '/content/drive/MyDrive/ColabRun/AE01'
In [ ]:
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
AE.model.compile(optimizer=optimizer, loss=AutoEncoder.r_loss)
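# Note: AutoEncoder.r_loss (defined in nw/AutoEncoder.py above) is the per-image
# mean squared error: it averages squared pixel differences over the height, width
# and channel axes, giving one reconstruction error per image in the batch.
# Minimal sanity check of that behaviour (the shapes below are illustrative only):
_y_true = tf.zeros((4, 28, 28, 1))
_y_pred = 0.5 * tf.ones((4, 28, 28, 1))
print(AutoEncoder.r_loss(_y_true, _y_pred))  # tf.Tensor([0.25 0.25 0.25 0.25], shape=(4,), dtype=float32)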
In [ ]:
# At first, train for a few epochs.

history=AE.train_with_fit(
    x_train,
    x_train,
    batch_size=32,
    epochs = 3,
    run_folder = save_path1,
    validation_data = (x_test, x_test)
)
Epoch 1/3
1875/1875 [==============================] - 27s 6ms/step - loss: 0.0550 - val_loss: 0.0487
Epoch 2/3
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0464 - val_loss: 0.0449
Epoch 3/3
1875/1875 [==============================] - 10s 6ms/step - loss: 0.0444 - val_loss: 0.0439
In [ ]:
print(history.history)
{'loss': [0.05495461821556091, 0.0464060977101326, 0.04438251256942749], 'val_loss': [0.04873434454202652, 0.04490825906395912, 0.043926868587732315]}
In [ ]:
# Load the training state saved earlier.

AE_work = AutoEncoder.load(save_path1)

# display the epoch count of training
# training のepoch回数を表示する
print(AE_work.epoch)
3
In [ ]:
# Then, train for more epochs. The training continues from the current self.epoch up to the specified number of epochs.

AE_work.model.compile(optimizer, loss=AutoEncoder.r_loss)

history_work = AE_work.train_with_fit(
    x_train,
    x_train,
    batch_size=32,
    epochs=MAX_EPOCHS,
    run_folder = save_path1,
    validation_data=(x_test, x_test)
)
Epoch 4/200
1875/1875 [==============================] - 11s 5ms/step - loss: 0.0439 - val_loss: 0.0428
Epoch 5/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0422 - val_loss: 0.0421
Epoch 6/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0417 - val_loss: 0.0414
Epoch 7/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0412 - val_loss: 0.0412
Epoch 8/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0409 - val_loss: 0.0409
Epoch 9/200
1875/1875 [==============================] - 10s 6ms/step - loss: 0.0406 - val_loss: 0.0411
Epoch 10/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0404 - val_loss: 0.0405
Epoch 11/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0402 - val_loss: 0.0403
Epoch 12/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0401 - val_loss: 0.0402
Epoch 13/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0399 - val_loss: 0.0402
Epoch 14/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0398 - val_loss: 0.0399
Epoch 15/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0396 - val_loss: 0.0401
Epoch 16/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0395 - val_loss: 0.0406
Epoch 17/200
1875/1875 [==============================] - 10s 6ms/step - loss: 0.0394 - val_loss: 0.0399
Epoch 18/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0393 - val_loss: 0.0400
Epoch 19/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0392 - val_loss: 0.0395
Epoch 20/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0391 - val_loss: 0.0393
Epoch 21/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0390 - val_loss: 0.0393
Epoch 22/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0390 - val_loss: 0.0397
Epoch 23/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0389 - val_loss: 0.0394
Epoch 24/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0388 - val_loss: 0.0395
Epoch 25/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0388 - val_loss: 0.0393
Epoch 26/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0387 - val_loss: 0.0398
Epoch 27/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0387 - val_loss: 0.0395
Epoch 28/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0386 - val_loss: 0.0395
Epoch 29/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0385 - val_loss: 0.0390
Epoch 30/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0385 - val_loss: 0.0391
Epoch 31/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0384 - val_loss: 0.0395
Epoch 32/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0384 - val_loss: 0.0391
Epoch 33/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0383 - val_loss: 0.0394
Epoch 34/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0383 - val_loss: 0.0390
Epoch 35/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0382 - val_loss: 0.0393
Epoch 36/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0382 - val_loss: 0.0391
Epoch 37/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0381 - val_loss: 0.0390
Epoch 38/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0381 - val_loss: 0.0391
Epoch 39/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0381 - val_loss: 0.0388
Epoch 40/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0380 - val_loss: 0.0392
Epoch 41/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0380 - val_loss: 0.0394
Epoch 42/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0379 - val_loss: 0.0389
Epoch 43/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0379 - val_loss: 0.0392
Epoch 44/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0379 - val_loss: 0.0393
Epoch 45/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0378 - val_loss: 0.0390
Epoch 46/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0378 - val_loss: 0.0389
Epoch 47/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0378 - val_loss: 0.0392
Epoch 48/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0377 - val_loss: 0.0390
Epoch 49/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0377 - val_loss: 0.0386
Epoch 50/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0376 - val_loss: 0.0391
Epoch 51/200
1875/1875 [==============================] - 10s 6ms/step - loss: 0.0377 - val_loss: 0.0392
Epoch 52/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0376 - val_loss: 0.0391
Epoch 53/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0376 - val_loss: 0.0385
Epoch 54/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0386
Epoch 55/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0386
Epoch 56/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0387
Epoch 57/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0385
Epoch 58/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0375 - val_loss: 0.0386
Epoch 59/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0374 - val_loss: 0.0389
Epoch 60/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0374 - val_loss: 0.0387
Epoch 61/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0374 - val_loss: 0.0387
Epoch 62/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0373 - val_loss: 0.0386
Epoch 63/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0373 - val_loss: 0.0387
Epoch 64/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0373 - val_loss: 0.0387
Epoch 65/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0373 - val_loss: 0.0384
Epoch 66/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0385
Epoch 67/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0386
Epoch 68/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0386
Epoch 69/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0389
Epoch 70/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0385
Epoch 71/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0372 - val_loss: 0.0384
Epoch 72/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0387
Epoch 73/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0386
Epoch 74/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0384
Epoch 75/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0371 - val_loss: 0.0387
Epoch 76/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0370 - val_loss: 0.0387
Epoch 77/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0370 - val_loss: 0.0383
Epoch 78/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0370 - val_loss: 0.0385
Epoch 79/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0370 - val_loss: 0.0384
Epoch 80/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0385
Epoch 81/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0384
Epoch 82/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0383
Epoch 83/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0385
Epoch 84/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0385
Epoch 85/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0386
Epoch 86/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0369 - val_loss: 0.0386
Epoch 87/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0384
Epoch 88/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0383
Epoch 89/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0385
Epoch 90/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0384
Epoch 91/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0368 - val_loss: 0.0384
Epoch 92/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0386
Epoch 93/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0385
Epoch 94/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0382
Epoch 95/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0383
Epoch 96/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0383
Epoch 97/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0367 - val_loss: 0.0384
Epoch 98/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0383
Epoch 99/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0385
Epoch 100/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0385
Epoch 101/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0383
Epoch 102/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0384
Epoch 103/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0384
Epoch 104/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0385
Epoch 105/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0366 - val_loss: 0.0386
Epoch 106/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0382
Epoch 107/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0383
Epoch 108/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0384
Epoch 109/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0381
Epoch 110/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0385
Epoch 111/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0365 - val_loss: 0.0385
Epoch 112/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0385
Epoch 113/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0383
Epoch 114/200
1875/1875 [==============================] - 10s 6ms/step - loss: 0.0364 - val_loss: 0.0384
Epoch 115/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0381
Epoch 116/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0385
Epoch 117/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0382
Epoch 118/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0364 - val_loss: 0.0386
Epoch 119/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0383
Epoch 120/200
1875/1875 [==============================] - 10s 6ms/step - loss: 0.0364 - val_loss: 0.0383
Epoch 121/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0385
Epoch 122/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0389
Epoch 123/200
1875/1875 [==============================] - 10s 6ms/step - loss: 0.0363 - val_loss: 0.0385
Epoch 124/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0383
Epoch 125/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0385
Epoch 126/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0384
Epoch 127/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384
Epoch 128/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0363 - val_loss: 0.0384
Epoch 129/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384
Epoch 130/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0388
Epoch 131/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384
Epoch 132/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384
Epoch 133/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384
Epoch 134/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0384
Epoch 135/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0386
Epoch 136/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0362 - val_loss: 0.0385
Epoch 137/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0383
Epoch 138/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0383
Epoch 139/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0383
Epoch 140/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0384
Epoch 141/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0385
Epoch 142/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0385
Epoch 143/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0386
Epoch 144/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0382
Epoch 145/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0384
Epoch 146/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0387
Epoch 147/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0361 - val_loss: 0.0384
Epoch 148/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0385
Epoch 149/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0382
Epoch 150/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0381
Epoch 151/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0385
Epoch 152/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0382
Epoch 153/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0385
Epoch 154/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0384
Epoch 155/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0384
Epoch 156/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0382
Epoch 157/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0384
Epoch 158/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0384
Epoch 159/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0383
Epoch 160/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0360 - val_loss: 0.0387
Epoch 161/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0383
Epoch 162/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0383
Epoch 163/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0384
Epoch 164/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0386
Epoch 165/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0383
Epoch 166/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0382
Epoch 167/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0384
Epoch 168/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0382
Epoch 169/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0359 - val_loss: 0.0387
Epoch 170/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0389
Epoch 171/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0381
Epoch 172/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0390
Epoch 173/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0386
Epoch 174/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0382
Epoch 175/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0382
Epoch 176/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0388
Epoch 177/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0383
Epoch 178/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0383
Epoch 179/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0382
Epoch 180/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0358 - val_loss: 0.0384
Epoch 181/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383
Epoch 182/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383
Epoch 183/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383
Epoch 184/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0385
Epoch 185/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383
Epoch 186/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0386
Epoch 187/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383
Epoch 188/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383
Epoch 189/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0386
Epoch 190/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383
Epoch 191/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0383
Epoch 192/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0382
Epoch 193/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0357 - val_loss: 0.0386
Epoch 194/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0383
Epoch 195/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0385
Epoch 196/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0383
Epoch 197/200
1875/1875 [==============================] - 10s 6ms/step - loss: 0.0356 - val_loss: 0.0382
Epoch 198/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0385
Epoch 199/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0384
Epoch 200/200
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0356 - val_loss: 0.0384
In [ ]:
# The return value contains the loss values from the additional training.
print(len(history_work.history['loss']))
197
In [ ]:
loss1_1 = history.history['loss']
vloss1_1 = history.history['val_loss']

loss1_2 = history_work.history['loss']
vloss1_2 = history_work.history['val_loss']

loss1 = np.concatenate([loss1_1, loss1_2], axis=0)
val_loss1 = np.concatenate([vloss1_1, vloss1_2], axis=0)
In [ ]:
AutoEncoder.plot_history([loss1, val_loss1], ['loss', 'val_loss'])

Validate the training results.

In [ ]:
selected_indices = np.random.choice(range(len(x_test)), 10)
selected_images = x_test[selected_indices]
In [ ]:
z_points = AE_work.encoder.predict(selected_images)
reconst_images = AE_work.decoder.predict(z_points)

txts = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z_points ]
In [ ]:
%matplotlib inline

AutoEncoder.showImages(selected_images, reconst_images, txts, 1.4, 1.4)

(2) Training with tf.GradientTape()

Instead of using fit(), compute the loss inside a custom train() function, obtain the gradients with tf.GradientTape(), and apply them to the variables.

The train_tf() function is sped up by declaring the compute_loss_and_grads() function with <code>@tf.function</code>.

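A minimal sketch of that pattern is shown below; the names are illustrative, and the actual implementation is in the train(), train_tf() and compute_loss_and_grads() methods listed above.

import tensorflow as tf

@tf.function  # compile the step into a graph, as compute_loss_and_grads() does
def train_step(model, optimizer, x_batch, y_batch):
    with tf.GradientTape() as tape:
        outputs = model(x_batch)
        # per-image reconstruction error (same definition as AutoEncoder.r_loss)
        loss = tf.reduce_mean(tf.square(y_batch - outputs), axis=[1, 2, 3])
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss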

In [ ]:
save_path2 = '/content/drive/MyDrive/ColabRun/AE02/'
In [ ]:
from nw.AutoEncoder import AutoEncoder

AE2 = AutoEncoder(
    input_dim = (28, 28, 1),
    encoder_conv_filters = [32, 64, 64, 64],
    encoder_conv_kernel_size = [3, 3, 3, 3],
    encoder_conv_strides = [1, 2, 2, 1],
    decoder_conv_t_filters = [64, 64, 32, 1],
    decoder_conv_t_kernel_size = [3, 3, 3, 3],
    decoder_conv_t_strides = [1, 2, 2, 1],
    z_dim = 2
)
In [ ]:
optimizer2 = tf.keras.optimizers.Adam(learning_rate=learning_rate)
In [ ]:
# At first, train for a few epochs.

loss2_1, vloss2_1 = AE2.train(
    x_train,
    x_train,
    batch_size=32,
    epochs = 3, 
    shuffle=True,
    run_folder= save_path2,
    optimizer = optimizer2,
    save_epoch_interval=50,
    validation_data=(x_test, x_test)
    )
1/3 1875 loss: 0.0406  val loss: 0.0481  0:00:39.580740
2/3 1875 loss: 0.0391  val loss: 0.0448  0:01:18.183180
3/3 1875 loss: 0.0529  val loss: 0.0432  0:01:56.549291
In [ ]:
# Load the parameters and the weights saved before.

AE2_work = AutoEncoder.load(save_path2)
print(AE2_work.epoch)
3
In [ ]:
# Additional training.

# train_tf() compiles the loss-and-gradients computation into a TensorFlow graph,
# so it is a little over twice as fast as train(). However, it is still nearly twice as slow as fit().

loss2_2, vloss2_2 = AE2_work.train_tf(
    x_train,
    x_train,
    batch_size=32,
    epochs = MAX_EPOCHS, 
    shuffle=True,
    run_folder= save_path2,
    optimizer = optimizer2,
    save_epoch_interval=50,
    validation_data=(x_test, x_test)
    )
4/200 1875 loss: 0.0441  val loss: 0.0430  0:00:16.042304
5/200 1875 loss: 0.0425  val loss: 0.0423  0:00:31.745769
6/200 1875 loss: 0.0419  val loss: 0.0420  0:00:47.550041
7/200 1875 loss: 0.0415  val loss: 0.0414  0:01:03.278043
8/200 1875 loss: 0.0412  val loss: 0.0411  0:01:18.886151
9/200 1875 loss: 0.0409  val loss: 0.0408  0:01:34.403325
10/200 1875 loss: 0.0406  val loss: 0.0404  0:01:49.879346
11/200 1875 loss: 0.0404  val loss: 0.0406  0:02:05.370021
12/200 1875 loss: 0.0403  val loss: 0.0406  0:02:21.087446
13/200 1875 loss: 0.0401  val loss: 0.0401  0:02:36.769709
14/200 1875 loss: 0.0399  val loss: 0.0401  0:02:52.391754
15/200 1875 loss: 0.0398  val loss: 0.0401  0:03:07.901212
16/200 1875 loss: 0.0397  val loss: 0.0400  0:03:23.392766
17/200 1875 loss: 0.0396  val loss: 0.0397  0:03:38.888386
18/200 1875 loss: 0.0394  val loss: 0.0396  0:03:54.470497
19/200 1875 loss: 0.0393  val loss: 0.0397  0:04:10.006842
20/200 1875 loss: 0.0392  val loss: 0.0396  0:04:25.572727
21/200 1875 loss: 0.0391  val loss: 0.0395  0:04:41.166492
22/200 1875 loss: 0.0391  val loss: 0.0395  0:04:56.777322
23/200 1875 loss: 0.0390  val loss: 0.0393  0:05:12.350585
24/200 1875 loss: 0.0389  val loss: 0.0394  0:05:28.061553
25/200 1875 loss: 0.0388  val loss: 0.0400  0:05:43.522782
26/200 1875 loss: 0.0387  val loss: 0.0391  0:05:59.371244
27/200 1875 loss: 0.0387  val loss: 0.0394  0:06:15.072191
28/200 1875 loss: 0.0386  val loss: 0.0394  0:06:30.590397
29/200 1875 loss: 0.0385  val loss: 0.0389  0:06:46.193198
30/200 1875 loss: 0.0385  val loss: 0.0393  0:07:01.744139
31/200 1875 loss: 0.0385  val loss: 0.0392  0:07:17.398188
32/200 1875 loss: 0.0384  val loss: 0.0391  0:07:33.097819
33/200 1875 loss: 0.0383  val loss: 0.0388  0:07:48.744927
34/200 1875 loss: 0.0382  val loss: 0.0388  0:08:04.346553
35/200 1875 loss: 0.0382  val loss: 0.0389  0:08:19.798364
36/200 1875 loss: 0.0381  val loss: 0.0390  0:08:35.371745
37/200 1875 loss: 0.0381  val loss: 0.0386  0:08:51.082935
38/200 1875 loss: 0.0380  val loss: 0.0388  0:09:06.822892
39/200 1875 loss: 0.0380  val loss: 0.0385  0:09:22.394035
40/200 1875 loss: 0.0379  val loss: 0.0389  0:09:37.953852
41/200 1875 loss: 0.0379  val loss: 0.0387  0:09:53.514515
42/200 1875 loss: 0.0378  val loss: 0.0387  0:10:09.035459
43/200 1875 loss: 0.0378  val loss: 0.0386  0:10:24.666275
44/200 1875 loss: 0.0378  val loss: 0.0388  0:10:40.257557
45/200 1875 loss: 0.0377  val loss: 0.0385  0:10:55.745129
46/200 1875 loss: 0.0377  val loss: 0.0388  0:11:11.432226
47/200 1875 loss: 0.0376  val loss: 0.0386  0:11:27.144229
48/200 1875 loss: 0.0376  val loss: 0.0390  0:11:42.676218
49/200 1875 loss: 0.0375  val loss: 0.0388  0:11:58.487087
50/200 1875 loss: 0.0375  val loss: 0.0383  0:12:14.850839
51/200 1875 loss: 0.0375  val loss: 0.0390  0:12:30.468638
52/200 1875 loss: 0.0375  val loss: 0.0388  0:12:46.049661
53/200 1875 loss: 0.0374  val loss: 0.0384  0:13:01.636156
54/200 1875 loss: 0.0374  val loss: 0.0383  0:13:17.325480
55/200 1875 loss: 0.0374  val loss: 0.0385  0:13:32.836645
56/200 1875 loss: 0.0374  val loss: 0.0388  0:13:48.441919
57/200 1875 loss: 0.0373  val loss: 0.0384  0:14:03.917869
58/200 1875 loss: 0.0373  val loss: 0.0388  0:14:19.634660
59/200 1875 loss: 0.0372  val loss: 0.0389  0:14:35.261167
60/200 1875 loss: 0.0372  val loss: 0.0384  0:14:50.896159
61/200 1875 loss: 0.0372  val loss: 0.0390  0:15:06.445663
62/200 1875 loss: 0.0372  val loss: 0.0381  0:15:22.134292
63/200 1875 loss: 0.0372  val loss: 0.0382  0:15:37.757501
64/200 1875 loss: 0.0371  val loss: 0.0384  0:15:53.316315
65/200 1875 loss: 0.0371  val loss: 0.0382  0:16:08.820412
66/200 1875 loss: 0.0371  val loss: 0.0385  0:16:24.565601
67/200 1875 loss: 0.0370  val loss: 0.0384  0:16:40.101123
68/200 1875 loss: 0.0370  val loss: 0.0383  0:16:55.609609
69/200 1875 loss: 0.0370  val loss: 0.0382  0:17:11.264953
70/200 1875 loss: 0.0370  val loss: 0.0383  0:17:26.949355
71/200 1875 loss: 0.0370  val loss: 0.0381  0:17:42.623016
72/200 1875 loss: 0.0369  val loss: 0.0381  0:17:58.321779
73/200 1875 loss: 0.0369  val loss: 0.0382  0:18:13.832138
74/200 1875 loss: 0.0369  val loss: 0.0381  0:18:29.598127
75/200 1875 loss: 0.0369  val loss: 0.0383  0:18:45.208392
76/200 1875 loss: 0.0368  val loss: 0.0385  0:19:00.743062
77/200 1875 loss: 0.0368  val loss: 0.0381  0:19:16.186948
78/200 1875 loss: 0.0368  val loss: 0.0381  0:19:31.760451
79/200 1875 loss: 0.0368  val loss: 0.0385  0:19:47.388234
80/200 1875 loss: 0.0367  val loss: 0.0383  0:20:02.935055
81/200 1875 loss: 0.0367  val loss: 0.0385  0:20:18.402500
82/200 1875 loss: 0.0367  val loss: 0.0381  0:20:33.940910
83/200 1875 loss: 0.0367  val loss: 0.0384  0:20:49.569920
84/200 1875 loss: 0.0367  val loss: 0.0385  0:21:05.242798
85/200 1875 loss: 0.0366  val loss: 0.0382  0:21:20.880114
86/200 1875 loss: 0.0367  val loss: 0.0381  0:21:36.641503
87/200 1875 loss: 0.0366  val loss: 0.0381  0:21:52.095492
88/200 1875 loss: 0.0366  val loss: 0.0379  0:22:07.601546
89/200 1875 loss: 0.0366  val loss: 0.0381  0:22:23.401748
90/200 1875 loss: 0.0366  val loss: 0.0387  0:22:39.066528
91/200 1875 loss: 0.0366  val loss: 0.0387  0:22:54.610725
92/200 1875 loss: 0.0365  val loss: 0.0385  0:23:10.169099
93/200 1875 loss: 0.0365  val loss: 0.0385  0:23:25.674254
94/200 1875 loss: 0.0365  val loss: 0.0381  0:23:41.366783
95/200 1875 loss: 0.0365  val loss: 0.0382  0:23:56.902391
96/200 1875 loss: 0.0365  val loss: 0.0382  0:24:12.496421
97/200 1875 loss: 0.0365  val loss: 0.0383  0:24:28.063963
98/200 1875 loss: 0.0364  val loss: 0.0384  0:24:43.599283
99/200 1875 loss: 0.0365  val loss: 0.0381  0:24:59.157835
100/200 1875 loss: 0.0364  val loss: 0.0379  0:25:15.526026
101/200 1875 loss: 0.0364  val loss: 0.0387  0:25:31.212898
102/200 1875 loss: 0.0364  val loss: 0.0383  0:25:46.802330
103/200 1875 loss: 0.0364  val loss: 0.0382  0:26:02.178094
104/200 1875 loss: 0.0364  val loss: 0.0382  0:26:17.746102
105/200 1875 loss: 0.0363  val loss: 0.0382  0:26:33.309578
106/200 1875 loss: 0.0363  val loss: 0.0384  0:26:49.121648
107/200 1875 loss: 0.0363  val loss: 0.0381  0:27:04.702489
108/200 1875 loss: 0.0363  val loss: 0.0382  0:27:20.170574
109/200 1875 loss: 0.0363  val loss: 0.0379  0:27:35.856174
110/200 1875 loss: 0.0363  val loss: 0.0381  0:27:51.299808
111/200 1875 loss: 0.0362  val loss: 0.0384  0:28:06.870872
112/200 1875 loss: 0.0362  val loss: 0.0381  0:28:22.438025
113/200 1875 loss: 0.0362  val loss: 0.0383  0:28:37.875336
114/200 1875 loss: 0.0362  val loss: 0.0385  0:28:53.328504
115/200 1875 loss: 0.0362  val loss: 0.0382  0:29:08.972971
116/200 1875 loss: 0.0362  val loss: 0.0379  0:29:24.502631
117/200 1875 loss: 0.0362  val loss: 0.0382  0:29:39.941896
118/200 1875 loss: 0.0362  val loss: 0.0381  0:29:55.477538
119/200 1875 loss: 0.0362  val loss: 0.0384  0:30:11.112526
120/200 1875 loss: 0.0361  val loss: 0.0381  0:30:26.374847
121/200 1875 loss: 0.0361  val loss: 0.0380  0:30:41.861327
122/200 1875 loss: 0.0361  val loss: 0.0383  0:30:57.370377
123/200 1875 loss: 0.0361  val loss: 0.0381  0:31:12.900791
124/200 1875 loss: 0.0361  val loss: 0.0380  0:31:28.312363
125/200 1875 loss: 0.0361  val loss: 0.0380  0:31:43.843139
126/200 1875 loss: 0.0361  val loss: 0.0385  0:31:59.553265
127/200 1875 loss: 0.0361  val loss: 0.0385  0:32:14.916876
128/200 1875 loss: 0.0361  val loss: 0.0381  0:32:30.487089
129/200 1875 loss: 0.0360  val loss: 0.0380  0:32:45.878726
130/200 1875 loss: 0.0360  val loss: 0.0382  0:33:01.336908
131/200 1875 loss: 0.0360  val loss: 0.0377  0:33:16.793144
132/200 1875 loss: 0.0360  val loss: 0.0383  0:33:32.367575
133/200 1875 loss: 0.0360  val loss: 0.0383  0:33:47.764421
134/200 1875 loss: 0.0360  val loss: 0.0381  0:34:03.307962
135/200 1875 loss: 0.0360  val loss: 0.0383  0:34:18.773369
136/200 1875 loss: 0.0360  val loss: 0.0380  0:34:34.307721
137/200 1875 loss: 0.0360  val loss: 0.0382  0:34:49.981894
138/200 1875 loss: 0.0360  val loss: 0.0384  0:35:05.470105
139/200 1875 loss: 0.0359  val loss: 0.0383  0:35:20.803749
140/200 1875 loss: 0.0359  val loss: 0.0379  0:35:36.185748
141/200 1875 loss: 0.0359  val loss: 0.0382  0:35:51.533243
142/200 1875 loss: 0.0359  val loss: 0.0380  0:36:06.931450
143/200 1875 loss: 0.0359  val loss: 0.0381  0:36:22.431496
144/200 1875 loss: 0.0359  val loss: 0.0381  0:36:37.869902
145/200 1875 loss: 0.0359  val loss: 0.0384  0:36:53.547983
146/200 1875 loss: 0.0359  val loss: 0.0383  0:37:09.217082
147/200 1875 loss: 0.0359  val loss: 0.0382  0:37:24.778358
148/200 1875 loss: 0.0358  val loss: 0.0379  0:37:40.239433
149/200 1875 loss: 0.0359  val loss: 0.0381  0:37:55.704042
150/200 1875 loss: 0.0358  val loss: 0.0381  0:38:12.101171
151/200 1875 loss: 0.0358  val loss: 0.0380  0:38:27.662723
152/200 1875 loss: 0.0358  val loss: 0.0379  0:38:43.202257
153/200 1875 loss: 0.0358  val loss: 0.0385  0:38:58.810277
154/200 1875 loss: 0.0358  val loss: 0.0380  0:39:14.231378
155/200 1875 loss: 0.0358  val loss: 0.0381  0:39:29.652152
156/200 1875 loss: 0.0358  val loss: 0.0379  0:39:45.085332
157/200 1875 loss: 0.0358  val loss: 0.0380  0:40:00.572288
158/200 1875 loss: 0.0358  val loss: 0.0381  0:40:16.141797
159/200 1875 loss: 0.0357  val loss: 0.0381  0:40:31.634852
160/200 1875 loss: 0.0357  val loss: 0.0381  0:40:47.056919
161/200 1875 loss: 0.0357  val loss: 0.0383  0:41:02.554172
162/200 1875 loss: 0.0358  val loss: 0.0380  0:41:18.121788
163/200 1875 loss: 0.0357  val loss: 0.0379  0:41:33.599777
164/200 1875 loss: 0.0357  val loss: 0.0385  0:41:49.118886
165/200 1875 loss: 0.0357  val loss: 0.0378  0:42:04.560262
166/200 1875 loss: 0.0357  val loss: 0.0381  0:42:20.288644
167/200 1875 loss: 0.0357  val loss: 0.0381  0:42:35.660883
168/200 1875 loss: 0.0357  val loss: 0.0383  0:42:51.115505
169/200 1875 loss: 0.0357  val loss: 0.0380  0:43:06.762465
170/200 1875 loss: 0.0356  val loss: 0.0383  0:43:22.257651
171/200 1875 loss: 0.0356  val loss: 0.0383  0:43:37.670103
172/200 1875 loss: 0.0357  val loss: 0.0380  0:43:53.056826
173/200 1875 loss: 0.0357  val loss: 0.0381  0:44:08.524716
174/200 1875 loss: 0.0356  val loss: 0.0381  0:44:24.027149
175/200 1875 loss: 0.0356  val loss: 0.0379  0:44:39.346028
176/200 1875 loss: 0.0356  val loss: 0.0381  0:44:54.734347
177/200 1875 loss: 0.0356  val loss: 0.0384  0:45:10.213102
178/200 1875 loss: 0.0356  val loss: 0.0379  0:45:25.773002
179/200 1875 loss: 0.0356  val loss: 0.0382  0:45:41.326772
180/200 1875 loss: 0.0356  val loss: 0.0380  0:45:56.666135
181/200 1875 loss: 0.0356  val loss: 0.0382  0:46:11.978621
182/200 1875 loss: 0.0356  val loss: 0.0384  0:46:27.301725
183/200 1875 loss: 0.0356  val loss: 0.0381  0:46:42.745618
184/200 1875 loss: 0.0356  val loss: 0.0380  0:46:58.128569
185/200 1875 loss: 0.0355  val loss: 0.0380  0:47:13.711115
186/200 1875 loss: 0.0355  val loss: 0.0379  0:47:29.307111
187/200 1875 loss: 0.0356  val loss: 0.0382  0:47:44.756529
188/200 1875 loss: 0.0355  val loss: 0.0381  0:48:00.214915
189/200 1875 loss: 0.0355  val loss: 0.0383  0:48:15.668996
190/200 1875 loss: 0.0355  val loss: 0.0381  0:48:31.229319
191/200 1875 loss: 0.0355  val loss: 0.0382  0:48:46.675617
192/200 1875 loss: 0.0355  val loss: 0.0380  0:49:02.254153
193/200 1875 loss: 0.0355  val loss: 0.0382  0:49:17.595616
194/200 1875 loss: 0.0355  val loss: 0.0380  0:49:32.985089
195/200 1875 loss: 0.0355  val loss: 0.0381  0:49:48.470253
196/200 1875 loss: 0.0354  val loss: 0.0382  0:50:03.960498
197/200 1875 loss: 0.0355  val loss: 0.0383  0:50:19.343814
198/200 1875 loss: 0.0355  val loss: 0.0382  0:50:34.860656
199/200 1875 loss: 0.0355  val loss: 0.0381  0:50:50.302304
200/200 1875 loss: 0.0355  val loss: 0.0380  0:51:06.366335
In [ ]:
loss2 = np.concatenate([loss2_1, loss2_2], axis=0)
val_loss2 = np.concatenate([vloss2_1, vloss2_2], axis=0)

AutoEncoder.plot_history([loss2, val_loss2], ['loss', 'val_loss'])
In [ ]:
z_points2 = AE2_work.encoder.predict(selected_images)
reconst_images2 = AE2_work.decoder.predict(z_points2)

txts2 = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z_points2 ]
In [ ]:
%matplotlib inline

AutoEncoder.showImages(selected_images, reconst_images2, txts2, 1.4, 1.4)

(3) Training with tf.GradientTape() function and Learning rate decay

Compute the loss and gradients with the tf.GradientTape() function and apply the gradients to the model variables. In addition, apply learning rate decay in the optimizer (see the sketch below).

[Caution] If you call the save_images() function during training, encoder.predict() and decoder.predict() run at each save, so execution becomes very slow.

(3) tf.GradientTape() 関数と学習率減衰を使った学習

tf.GradientTape() 関数を使って loss と gradients を計算して、gradients を変数に適用する。 さらに、optimizer において Learning rate decay を行う。

(注意) training の途中で save_images() 関数を呼び出すと、encoder.predict() と decoder.predict() が動作して、実行が非常に遅くなるので注意すること。
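
The following is a minimal sketch of such a training step, not the actual nw.AutoEncoder implementation: the names train_step, encoder, decoder, and x_batch, the MSE reconstruction loss, and the initial learning rate 0.0005 are assumptions used only for illustration.

import tensorflow as tf

# Exponentially decaying learning rate fed to Adam (placeholder values).
lr = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.0005, decay_steps=1000, decay_rate=0.96)
opt = tf.keras.optimizers.Adam(learning_rate=lr)

def train_step(encoder, decoder, x_batch):
    # Record the forward pass so the tape can differentiate the loss.
    with tf.GradientTape() as tape:
        z = encoder(x_batch, training=True)                  # encode to latent points
        x_hat = decoder(z, training=True)                    # reconstruct the input
        loss = tf.reduce_mean(tf.square(x_batch - x_hat))    # reconstruction MSE
    variables = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(loss, variables)                   # d(loss)/d(weights)
    opt.apply_gradients(zip(grads, variables))                # update uses the decayed rate
    return loss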

In [ ]:
save_path3 = '/content/drive/MyDrive/ColabRun/AE03/'
In [ ]:
from nw.AutoEncoder import AutoEncoder

AE3 = AutoEncoder(
    input_dim = (28, 28, 1),
    encoder_conv_filters = [32, 64, 64, 64],
    encoder_conv_kernel_size = [3, 3, 3, 3],
    encoder_conv_strides = [1, 2, 2, 1],
    decoder_conv_t_filters = [64, 64, 32, 1],
    decoder_conv_t_kernel_size = [3, 3, 3, 3],
    decoder_conv_t_strides = [1, 2, 2, 1],
    z_dim = 2
)
In [ ]:
# decayed learning rate = initial_learning_rate * decay_rate ^ (step / decay_steps)
# (staircase=False is the default here; with staircase=True the exponent would use step // decay_steps)
# The step is the optimizer's update counter, so it advances once per batch (1875 times per epoch).

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate = learning_rate,
    decay_steps = 1000,
    decay_rate=0.96
)

optimizer3 = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
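
As a quick check, the schedule object is callable, so the decayed learning rate that Adam will use at a given optimizer step can be inspected directly; the step values below are arbitrary examples (with batch_size=32, one epoch corresponds to 1875 steps).

for step in [0, 1000, 1875, 10000]:
    print(step, float(lr_schedule(step)))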
In [ ]:
# At first, train for a few epochs.
# まず、少ない回数 training してみる

loss3_1, vloss3_1 = AE3.train(
    x_train,
    x_train,
    batch_size=32,
    epochs = 3, 
    shuffle=True,
    run_folder=save_path3,
    optimizer = optimizer3,
    save_epoch_interval=50,
    validation_data=(x_test, x_test)
    )
1/3 1875 loss: 0.0493  val loss: 0.0491  0:00:38.910057
2/3 1875 loss: 0.0451  val loss: 0.0451  0:01:17.719193
3/3 1875 loss: 0.0463  val loss: 0.0439  0:01:56.448946
In [ ]:
# Load the parameters and the weights saved before.
# 保存したパラメータと、重みを読み込む。

AE3_work = AutoEncoder.load(save_path3)
print(AE3_work.epoch)
3
In [ ]:
# Additional Training.
# 追加でtrainingする。

# train_tf() compiles the loss-and-gradient part into a Tensorflow 2 graph, so it is a little over twice as fast as train(). However, it is still nearly twice as slow as fit().
# train_tf() は loss と gradients を求める部分を tf のgraphにコンパイルしているので、train()よりも2倍強高速になっている。しかし、それでもfit()よりは2倍近く遅い。
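# [Hypothetical sketch] The actual train_tf() source is not shown here; it is assumed to
# wrap a step like compiled_step() below in @tf.function, so the loss/gradient computation
# is traced once into a graph instead of being re-executed eagerly for every batch.
@tf.function
def compiled_step(x):
    with tf.GradientTape() as tape:
        x_hat = AE3_work.decoder(AE3_work.encoder(x, training=True), training=True)
        loss = tf.reduce_mean(tf.square(x - x_hat))      # reconstruction MSE
    variables = (AE3_work.encoder.trainable_variables
                 + AE3_work.decoder.trainable_variables)
    optimizer3.apply_gradients(zip(tape.gradient(loss, variables), variables))
    return loss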

loss3_2, vloss3_2 = AE3_work.train_tf(
    x_train,
    x_train,
    batch_size=32,
    epochs = MAX_EPOCHS, 
    shuffle=True,
    run_folder= save_path3,
    optimizer = optimizer3,
    save_epoch_interval=50,
    validation_data=(x_test, x_test)
    )
4/200 1875 loss: 0.0441  val loss: 0.0432  0:00:16.470529
5/200 1875 loss: 0.0425  val loss: 0.0424  0:00:32.578862
6/200 1875 loss: 0.0418  val loss: 0.0416  0:00:48.702078
7/200 1875 loss: 0.0413  val loss: 0.0411  0:01:04.721034
8/200 1875 loss: 0.0410  val loss: 0.0408  0:01:20.828852
9/200 1875 loss: 0.0406  val loss: 0.0408  0:01:36.866276
10/200 1875 loss: 0.0403  val loss: 0.0405  0:01:52.938620
11/200 1875 loss: 0.0401  val loss: 0.0403  0:02:08.915218
12/200 1875 loss: 0.0398  val loss: 0.0403  0:02:24.993334
13/200 1875 loss: 0.0397  val loss: 0.0402  0:02:40.943671
14/200 1875 loss: 0.0395  val loss: 0.0400  0:02:57.026531
15/200 1875 loss: 0.0393  val loss: 0.0396  0:03:13.174460
16/200 1875 loss: 0.0392  val loss: 0.0397  0:03:29.264352
17/200 1875 loss: 0.0391  val loss: 0.0397  0:03:45.378318
18/200 1875 loss: 0.0389  val loss: 0.0395  0:04:01.329874
19/200 1875 loss: 0.0388  val loss: 0.0393  0:04:17.245191
20/200 1875 loss: 0.0387  val loss: 0.0394  0:04:33.400665
21/200 1875 loss: 0.0386  val loss: 0.0392  0:04:49.561788
22/200 1875 loss: 0.0385  val loss: 0.0391  0:05:05.574827
23/200 1875 loss: 0.0384  val loss: 0.0392  0:05:21.680731
24/200 1875 loss: 0.0384  val loss: 0.0390  0:05:37.694906
25/200 1875 loss: 0.0383  val loss: 0.0391  0:05:53.796563
26/200 1875 loss: 0.0382  val loss: 0.0390  0:06:09.769820
27/200 1875 loss: 0.0382  val loss: 0.0388  0:06:25.773891
28/200 1875 loss: 0.0381  val loss: 0.0389  0:06:41.814141
29/200 1875 loss: 0.0380  val loss: 0.0389  0:06:57.785303
30/200 1875 loss: 0.0380  val loss: 0.0389  0:07:13.735852
31/200 1875 loss: 0.0379  val loss: 0.0388  0:07:29.669668
32/200 1875 loss: 0.0379  val loss: 0.0388  0:07:45.531500
33/200 1875 loss: 0.0379  val loss: 0.0388  0:08:01.576522
34/200 1875 loss: 0.0378  val loss: 0.0387  0:08:17.534843
35/200 1875 loss: 0.0378  val loss: 0.0387  0:08:33.530811
36/200 1875 loss: 0.0378  val loss: 0.0387  0:08:49.452516
37/200 1875 loss: 0.0377  val loss: 0.0386  0:09:05.483624
38/200 1875 loss: 0.0377  val loss: 0.0387  0:09:21.607403
39/200 1875 loss: 0.0377  val loss: 0.0386  0:09:37.595950
40/200 1875 loss: 0.0376  val loss: 0.0386  0:09:53.711715
41/200 1875 loss: 0.0376  val loss: 0.0386  0:10:09.573170
42/200 1875 loss: 0.0376  val loss: 0.0386  0:10:25.672371
43/200 1875 loss: 0.0376  val loss: 0.0386  0:10:41.658941
44/200 1875 loss: 0.0376  val loss: 0.0386  0:10:57.703071
45/200 1875 loss: 0.0375  val loss: 0.0386  0:11:13.708545
46/200 1875 loss: 0.0375  val loss: 0.0386  0:11:29.633509
47/200 1875 loss: 0.0375  val loss: 0.0386  0:11:45.618107
48/200 1875 loss: 0.0375  val loss: 0.0386  0:12:01.542282
49/200 1875 loss: 0.0375  val loss: 0.0386  0:12:17.577870
50/200 1875 loss: 0.0375  val loss: 0.0386  0:12:34.355703
51/200 1875 loss: 0.0375  val loss: 0.0386  0:12:50.369939
52/200 1875 loss: 0.0374  val loss: 0.0385  0:13:06.322242
53/200 1875 loss: 0.0374  val loss: 0.0385  0:13:22.294449
54/200 1875 loss: 0.0374  val loss: 0.0386  0:13:38.360678
55/200 1875 loss: 0.0374  val loss: 0.0385  0:13:54.381169
56/200 1875 loss: 0.0374  val loss: 0.0385  0:14:10.436163
57/200 1875 loss: 0.0374  val loss: 0.0385  0:14:26.374065
58/200 1875 loss: 0.0374  val loss: 0.0385  0:14:42.461288
59/200 1875 loss: 0.0374  val loss: 0.0385  0:14:58.804114
60/200 1875 loss: 0.0374  val loss: 0.0385  0:15:14.948368
61/200 1875 loss: 0.0374  val loss: 0.0385  0:15:31.003709
62/200 1875 loss: 0.0374  val loss: 0.0385  0:15:47.036744
63/200 1875 loss: 0.0374  val loss: 0.0385  0:16:02.927857
64/200 1875 loss: 0.0374  val loss: 0.0385  0:16:19.033550
65/200 1875 loss: 0.0374  val loss: 0.0385  0:16:35.146949
66/200 1875 loss: 0.0374  val loss: 0.0385  0:16:51.198622
67/200 1875 loss: 0.0374  val loss: 0.0385  0:17:07.332641
68/200 1875 loss: 0.0374  val loss: 0.0385  0:17:23.460586
69/200 1875 loss: 0.0373  val loss: 0.0385  0:17:39.446418
70/200 1875 loss: 0.0373  val loss: 0.0385  0:17:55.528039
71/200 1875 loss: 0.0373  val loss: 0.0385  0:18:11.502721
72/200 1875 loss: 0.0373  val loss: 0.0385  0:18:27.473277
73/200 1875 loss: 0.0373  val loss: 0.0385  0:18:43.449839
74/200 1875 loss: 0.0373  val loss: 0.0385  0:18:59.486046
75/200 1875 loss: 0.0373  val loss: 0.0385  0:19:15.552210
76/200 1875 loss: 0.0373  val loss: 0.0385  0:19:31.626925
77/200 1875 loss: 0.0373  val loss: 0.0385  0:19:47.513213
78/200 1875 loss: 0.0373  val loss: 0.0385  0:20:03.722407
79/200 1875 loss: 0.0373  val loss: 0.0385  0:20:19.804345
80/200 1875 loss: 0.0373  val loss: 0.0385  0:20:35.791387
81/200 1875 loss: 0.0373  val loss: 0.0385  0:20:51.787138
82/200 1875 loss: 0.0373  val loss: 0.0385  0:21:07.730577
83/200 1875 loss: 0.0373  val loss: 0.0385  0:21:23.787720
84/200 1875 loss: 0.0373  val loss: 0.0385  0:21:39.687327
85/200 1875 loss: 0.0373  val loss: 0.0385  0:21:55.584867
86/200 1875 loss: 0.0373  val loss: 0.0385  0:22:11.523757
87/200 1875 loss: 0.0373  val loss: 0.0385  0:22:27.390298
88/200 1875 loss: 0.0373  val loss: 0.0385  0:22:43.472842
89/200 1875 loss: 0.0373  val loss: 0.0385  0:22:59.672602
90/200 1875 loss: 0.0373  val loss: 0.0385  0:23:15.707643
91/200 1875 loss: 0.0373  val loss: 0.0385  0:23:31.828762
92/200 1875 loss: 0.0373  val loss: 0.0385  0:23:47.825785
93/200 1875 loss: 0.0373  val loss: 0.0385  0:24:03.841613
94/200 1875 loss: 0.0373  val loss: 0.0385  0:24:19.741862
95/200 1875 loss: 0.0373  val loss: 0.0385  0:24:35.758435
96/200 1875 loss: 0.0373  val loss: 0.0385  0:24:51.756826
97/200 1875 loss: 0.0373  val loss: 0.0385  0:25:07.796537
98/200 1875 loss: 0.0373  val loss: 0.0385  0:25:24.006386
99/200 1875 loss: 0.0373  val loss: 0.0385  0:25:39.953159
100/200 1875 loss: 0.0373  val loss: 0.0385  0:25:56.680088
101/200 1875 loss: 0.0373  val loss: 0.0385  0:26:12.840496
102/200 1875 loss: 0.0373  val loss: 0.0385  0:26:28.849194
103/200 1875 loss: 0.0373  val loss: 0.0385  0:26:44.841663
104/200 1875 loss: 0.0373  val loss: 0.0385  0:27:00.859847
105/200 1875 loss: 0.0373  val loss: 0.0385  0:27:17.004825
106/200 1875 loss: 0.0373  val loss: 0.0385  0:27:33.131367
107/200 1875 loss: 0.0373  val loss: 0.0385  0:27:49.215559
108/200 1875 loss: 0.0373  val loss: 0.0385  0:28:05.210165
109/200 1875 loss: 0.0373  val loss: 0.0385  0:28:21.219173
110/200 1875 loss: 0.0373  val loss: 0.0385  0:28:37.222607
111/200 1875 loss: 0.0373  val loss: 0.0385  0:28:53.464457
112/200 1875 loss: 0.0373  val loss: 0.0385  0:29:09.544785
113/200 1875 loss: 0.0373  val loss: 0.0385  0:29:25.686480
114/200 1875 loss: 0.0373  val loss: 0.0385  0:29:41.597519
115/200 1875 loss: 0.0373  val loss: 0.0385  0:29:57.662729
116/200 1875 loss: 0.0373  val loss: 0.0385  0:30:13.863950
117/200 1875 loss: 0.0373  val loss: 0.0385  0:30:30.316802
118/200 1875 loss: 0.0373  val loss: 0.0385  0:30:46.468128
119/200 1875 loss: 0.0373  val loss: 0.0385  0:31:02.499419
120/200 1875 loss: 0.0373  val loss: 0.0385  0:31:18.614042
121/200 1875 loss: 0.0373  val loss: 0.0385  0:31:35.006121
122/200 1875 loss: 0.0373  val loss: 0.0385  0:31:51.303524
123/200 1875 loss: 0.0373  val loss: 0.0385  0:32:07.364468
124/200 1875 loss: 0.0373  val loss: 0.0385  0:32:23.438019
125/200 1875 loss: 0.0373  val loss: 0.0385  0:32:39.503072
126/200 1875 loss: 0.0373  val loss: 0.0385  0:32:55.600306
127/200 1875 loss: 0.0373  val loss: 0.0385  0:33:11.714923
128/200 1875 loss: 0.0373  val loss: 0.0385  0:33:27.757243
129/200 1875 loss: 0.0373  val loss: 0.0385  0:33:43.846899
130/200 1875 loss: 0.0373  val loss: 0.0385  0:33:59.986440
131/200 1875 loss: 0.0373  val loss: 0.0385  0:34:16.036559
132/200 1875 loss: 0.0373  val loss: 0.0385  0:34:32.108351
133/200 1875 loss: 0.0373  val loss: 0.0385  0:34:48.126711
134/200 1875 loss: 0.0373  val loss: 0.0385  0:35:04.243871
135/200 1875 loss: 0.0373  val loss: 0.0385  0:35:20.342651
136/200 1875 loss: 0.0373  val loss: 0.0385  0:35:36.834930
137/200 1875 loss: 0.0373  val loss: 0.0385  0:35:53.240132
138/200 1875 loss: 0.0373  val loss: 0.0385  0:36:09.494644
139/200 1875 loss: 0.0373  val loss: 0.0385  0:36:25.711864
140/200 1875 loss: 0.0373  val loss: 0.0385  0:36:41.715972
141/200 1875 loss: 0.0373  val loss: 0.0385  0:36:57.798700
142/200 1875 loss: 0.0373  val loss: 0.0385  0:37:14.012655
143/200 1875 loss: 0.0373  val loss: 0.0385  0:37:29.964565
144/200 1875 loss: 0.0373  val loss: 0.0385  0:37:45.947853
145/200 1875 loss: 0.0373  val loss: 0.0385  0:38:01.922393
146/200 1875 loss: 0.0373  val loss: 0.0385  0:38:18.113317
147/200 1875 loss: 0.0373  val loss: 0.0385  0:38:34.302843
148/200 1875 loss: 0.0373  val loss: 0.0385  0:38:50.460746
149/200 1875 loss: 0.0373  val loss: 0.0385  0:39:06.514157
150/200 1875 loss: 0.0373  val loss: 0.0385  0:39:23.229584
151/200 1875 loss: 0.0373  val loss: 0.0385  0:39:39.470055
152/200 1875 loss: 0.0373  val loss: 0.0385  0:39:55.751577
153/200 1875 loss: 0.0373  val loss: 0.0385  0:40:12.016928
154/200 1875 loss: 0.0373  val loss: 0.0385  0:40:28.212051
155/200 1875 loss: 0.0373  val loss: 0.0385  0:40:44.578358
156/200 1875 loss: 0.0373  val loss: 0.0385  0:41:00.970150
157/200 1875 loss: 0.0373  val loss: 0.0385  0:41:17.138623
158/200 1875 loss: 0.0373  val loss: 0.0385  0:41:33.213235
159/200 1875 loss: 0.0373  val loss: 0.0385  0:41:49.222445
160/200 1875 loss: 0.0373  val loss: 0.0385  0:42:05.256580
161/200 1875 loss: 0.0373  val loss: 0.0385  0:42:21.305456
162/200 1875 loss: 0.0373  val loss: 0.0385  0:42:37.293515
163/200 1875 loss: 0.0373  val loss: 0.0385  0:42:53.218787
164/200 1875 loss: 0.0373  val loss: 0.0385  0:43:09.221011
165/200 1875 loss: 0.0373  val loss: 0.0385  0:43:25.241022
166/200 1875 loss: 0.0373  val loss: 0.0385  0:43:41.364001
167/200 1875 loss: 0.0373  val loss: 0.0385  0:43:57.468188
168/200 1875 loss: 0.0373  val loss: 0.0385  0:44:13.478234
169/200 1875 loss: 0.0373  val loss: 0.0385  0:44:29.480388
170/200 1875 loss: 0.0373  val loss: 0.0385  0:44:45.507111
171/200 1875 loss: 0.0373  val loss: 0.0385  0:45:01.481223
172/200 1875 loss: 0.0373  val loss: 0.0385  0:45:17.528489
173/200 1875 loss: 0.0373  val loss: 0.0385  0:45:33.518117
174/200 1875 loss: 0.0373  val loss: 0.0385  0:45:49.637809
175/200 1875 loss: 0.0373  val loss: 0.0385  0:46:05.890984
176/200 1875 loss: 0.0373  val loss: 0.0385  0:46:21.956460
177/200 1875 loss: 0.0373  val loss: 0.0385  0:46:38.062078
178/200 1875 loss: 0.0373  val loss: 0.0385  0:46:54.053728
179/200 1875 loss: 0.0373  val loss: 0.0385  0:47:09.917701
180/200 1875 loss: 0.0373  val loss: 0.0385  0:47:25.930845
181/200 1875 loss: 0.0373  val loss: 0.0385  0:47:41.949047
182/200 1875 loss: 0.0373  val loss: 0.0385  0:47:57.975498
183/200 1875 loss: 0.0373  val loss: 0.0385  0:48:13.821878
184/200 1875 loss: 0.0373  val loss: 0.0385  0:48:29.675598
185/200 1875 loss: 0.0373  val loss: 0.0385  0:48:45.563435
186/200 1875 loss: 0.0373  val loss: 0.0385  0:49:01.501470
187/200 1875 loss: 0.0373  val loss: 0.0385  0:49:17.551920
188/200 1875 loss: 0.0373  val loss: 0.0385  0:49:33.465461
189/200 1875 loss: 0.0373  val loss: 0.0385  0:49:49.437168
190/200 1875 loss: 0.0373  val loss: 0.0385  0:50:05.438445
191/200 1875 loss: 0.0373  val loss: 0.0385  0:50:21.633043
192/200 1875 loss: 0.0373  val loss: 0.0385  0:50:37.783884
193/200 1875 loss: 0.0373  val loss: 0.0385  0:50:53.888186
194/200 1875 loss: 0.0373  val loss: 0.0385  0:51:10.185000
195/200 1875 loss: 0.0373  val loss: 0.0385  0:51:26.449233
196/200 1875 loss: 0.0373  val loss: 0.0385  0:51:42.700071
197/200 1875 loss: 0.0373  val loss: 0.0385  0:51:58.913375
198/200 1875 loss: 0.0373  val loss: 0.0385  0:52:15.079322
199/200 1875 loss: 0.0373  val loss: 0.0385  0:52:30.994392
200/200 1875 loss: 0.0373  val loss: 0.0385  0:52:47.570349
In [ ]:
loss3 = np.concatenate([loss3_1, loss3_2], axis=0)
val_loss3 = np.concatenate([vloss3_1, vloss3_2], axis=0)

AutoEncoder.plot_history([loss3, val_loss3], ['loss', 'val_loss'])
In [ ]:
z_points3 = AE3_work.encoder.predict(selected_images)
reconst_images3 = AE3_work.decoder.predict(z_points3)

txts3 = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z_points3 ]
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f7c0c6eae60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f7c0c6afa70> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
In [ ]:
%matplotlib inline

AutoEncoder.showImages(selected_images, reconst_images3, txts3, 1.4, 1.4)
In [ ]: