Updated 19/Nov/2021 by Yoshihisa Nitta  

Variational Auto Encoder Training for the MNIST dataset with TensorFlow 2 on Google Colab

Train a Variational Auto Encoder on the MNIST dataset. A Variational Auto Encoder modifies the encoder and the loss function of an Auto Encoder.

In [1]:
#! pip install tensorflow==2.7.0
In [2]:
%tensorflow_version 2.x

import tensorflow as tf
print(tf.__version__)
2.7.0

About the Encoder of the VariationalAutoEncoder

With an AutoEncoder, each image is mapped directly to one point in the latent space. With a VariationalAutoEncoder, each image is mapped to a multivariate normal distribution around a point in the latent space.

The covariance matrix is a diagonal matrix because the VariationalAutoEncoder assumes that there is no correlation between any dimensions of the latent space. As a result, the encoder only needs to map each input to a mean vector and a variance vector, and does not have to model the correlation between dimensions. Furthermore, by mapping to the logarithm of the variance, any real number in the range $(-\infty, \infty)$ can be output.

Covariance Matrix

The variance-covariance matrix is obtained by extending the concept of variance (an index of how spread out the values are) to a multidimensional random variable. It is also simply called the covariance matrix.

Definition for 2 random variables

For the random variables $X_1$ and $X_2$, the variance-covariance matrix is defined as follows.

$\Sigma = \left( \begin{array}{cc} \sigma_1^2 & \sigma_{12} \\ \sigma_{12} & \sigma_2^2 \\ \end{array} \right )$

where
$\sigma_1^2 = \mbox{variance of } X_1 $,
$\sigma_2^2 = \mbox{variance of } X_2 $,
$\sigma_{12} = \mbox{covariance of } X_1 \mbox{ and } X_2$

$\Sigma$ is called a variance-covariance matrix because the variances are lined up on the diagonal components and the covariances are lined up on the off-diagonal components.

It is defined in the same way when there are $n$ random variables.

For the random variables $X_1$, $\cdots$, $X_n$, the $n\times n$ matrix $\Sigma$ whose $ii$ component is $\sigma_i^2$ and whose $ij$ component ($i\neq j$) is $\sigma_{ij}$ is called the variance-covariance matrix.

Example: find the variances and the covariance of 5 data points in 2 variables.

Suppose that $(x_i, y_i) = (4,5), (5, 7), (6,6), (7,9), (8,8)$ are given as data.

Mean of $x$ : $\mu_x = \displaystyle \frac{1}{5} (4+5+6+7+8)=6$,
Mean of $y$ : $\mu_y = \displaystyle \frac{1}{5}(5 + 7+6+9+8)=7$

Variance of $x$:
$\sigma_x^2 = \displaystyle\frac{1}{5}\sum_{i=1}^5 (x_i - \mu_x)^2 \\ \quad = \displaystyle \frac{1}{5} ((4-6)^2+(5-6)^2+(6-6)^2 + (7-6)^2 +(8-6)^2 ) \\ \quad = \displaystyle \frac{1}{5} (2^2 + 1^2 + 0^2 + 1^2 + 2^2) = 2$
Variance of $y$:
$\sigma_y^2 = \displaystyle\frac{1}{5} \sum_{i=1}^{5} (y_i - \mu_y)^2 \\ \quad = \displaystyle\frac{1}{5}((5-7)^2 +(7-7)^2 +(6-7)^2 +(9-7)^2 +(8-7)^2 ) \\ \quad = \displaystyle \frac{1}{5} (2^2 + 0^2 + 1^2 + 2^2 + 1^2) = 2$

Covariance of $x$ and $y$:
$\sigma_{xy} = \displaystyle \frac{1}{5} \sum_{i=1}^5 (x_i - \mu_x) (y_i - \mu_y) \\ \quad = \displaystyle \frac{1}{5} ((4-6)(5-7)+(5-6)(7-7)+(6-6)(6-7)+(7-6)(9-7)+(8-6)(8-7)) \\ \quad = \displaystyle \frac{1}{5}((-2)(-2) + (-1)\cdot 0 + 0 \cdot (-1) + 1 \cdot 2 + 2 \cdot 1)=\frac{8}{5}=1.6 $

Variance-covariance matrix of $x$ and $y$:
$\Sigma = \left( \begin{array}{cc} \sigma_x^2 & \sigma_{xy} \\ \sigma_{xy} & \sigma_y^2 \\ \end{array} \right) = \left( \begin{array}{cc} 2 & 1.6 \\ 1.6 & 2 \\ \end{array} \right)$
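
As a quick numerical check of the example above, the same quantities can be computed with NumPy. This is a small illustrative sketch, not part of the notebook's training code; np.cov(..., bias=True) divides by $n$, matching the formulas used here.

import numpy as np

# Data from the worked example above
x = np.array([4, 5, 6, 7, 8], dtype=float)
y = np.array([5, 7, 6, 9, 8], dtype=float)

mu_x, mu_y = x.mean(), y.mean()              # 6.0, 7.0
var_x = ((x - mu_x) ** 2).mean()             # 2.0
var_y = ((y - mu_y) ** 2).mean()             # 2.0
cov_xy = ((x - mu_x) * (y - mu_y)).mean()    # 1.6

# bias=True divides by n (not n-1), as in the calculation above
Sigma = np.cov(x, y, bias=True)
print(Sigma)                                 # [[2.  1.6]
                                             #  [1.6 2. ]]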

Normal Distribution

The probability density function of the one-dimensional normal distribution is
$\displaystyle f(x | \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} e^{-\frac{(x-\mu)^2}{2\sigma^2}}$
where $\mu$ is the mean, $\sigma^2$ the variance, and $\sigma$ the standard deviation.

In [3]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-5, 5, 200)

def f(x, m, v):
    d = x-m
    return np.exp(- d*d / (2.0 * v)) / np.sqrt(2 * np.pi * v)

fig, ax = plt.subplots(1,1,figsize=(8,6))

ax.plot(x, f(x, 0.0, 0.2),label='mean=0.0, variance=0.2',color='blue')
ax.plot(x, f(x, 0.0, 1.0),label='mean=0.0, variance=1.0',color='red')
ax.plot(x, f(x, 0.0, 5.0),label='mean=0.0, variance=5.0',color='orange')
ax.plot(x, f(x, -2.0, 0.5),label='mean=-2.0, variance=0.5',color='green')

plt.legend()
plt.show()
    

VariationalAutoEncoder

Sample the point $z$ as follows.

$ z = \mu + \sigma \epsilon$

$\mu$ determines where the point is centered, $\sigma$ how widely it may spread, and $\epsilon$ is a value drawn at random from the standard normal distribution. Since each point is placed around its $\mu$, the latent space is expected to become continuous.

Since the relationship $x = \displaystyle e^{\log x}$ holds, the following holds:
$\sigma = \displaystyle e^{\log \sigma} = \displaystyle e^{\frac{2 \log \sigma}{2}} = \displaystyle e^{\frac{\log \sigma^2}{2}}$

Therefore, using the logarithm of the variance computed by the encoder, $\mbox{log\_var} = \log \sigma^2$, calculate the following for each dimension:
$\mbox{sigma} = \sigma = \exp(\mbox{log\_var}/2)$
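
As a minimal sketch of this sampling step (the same computation is performed later by the Sampling layer in the downloaded class), assuming NumPy arrays mu and log_var produced by some encoder:

import numpy as np

# Hypothetical encoder outputs for a batch of 3 points in a 2-dimensional latent space
mu = np.array([[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]])
log_var = np.array([[0.0, -2.0], [1.0, 0.0], [-1.0, -1.0]])

epsilon = np.random.standard_normal(mu.shape)  # epsilon ~ N(0, 1), one value per dimension
sigma = np.exp(log_var / 2)                    # sigma = exp(log_var / 2)
z = mu + sigma * epsilon                       # z = mu + sigma * epsilon
print(z.shape)                                 # (3, 2)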

The features of the new variational encoder are as follows (a minimal sketch is shown after the list).

  • Instead of connecting the Flatten layer directly to the latent-space layer, connect it to the mu and log_var layers.
  • The Sampling layer samples points in the latent space from the normal distribution defined by mu and log_var.
  • The encoder model has three outputs: mu, log_var and z.
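
A minimal sketch of this encoder head, assuming the Sampling layer from the nw/VariationalAutoEncoder.py file downloaded later in this notebook and a 2-dimensional latent space; the convolutional layers are omitted, and the full encoder is built by the VariationalAutoEncoder class itself.

import tensorflow as tf
from nw.VariationalAutoEncoder import Sampling   # Sampling layer defined in the downloaded file

z_dim = 2
encoder_input = tf.keras.layers.Input(shape=(28, 28, 1), name='encoder_input')
x = tf.keras.layers.Flatten()(encoder_input)      # convolutional layers omitted in this sketch
mu = tf.keras.layers.Dense(z_dim, name='mu')(x)
log_var = tf.keras.layers.Dense(z_dim, name='log_var')(x)
z = Sampling(name='encoder_output')([mu, log_var])
encoder = tf.keras.models.Model(encoder_input, [mu, log_var, z], name='encoder')
encoder.summary()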

Check the Google Colab runtime environment

In [4]:
! nvidia-smi
! cat /proc/cpuinfo
! cat /etc/issue
! free -h
Mon Nov 22 06:48:42 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 85
model name	: Intel(R) Xeon(R) CPU @ 2.00GHz
stepping	: 3
microcode	: 0x1
cpu MHz		: 2000.188
cache size	: 39424 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4000.37
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

processor	: 1
vendor_id	: GenuineIntel
cpu family	: 6
model		: 85
model name	: Intel(R) Xeon(R) CPU @ 2.00GHz
stepping	: 3
microcode	: 0x1
cpu MHz		: 2000.188
cache size	: 39424 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 1
initial apicid	: 1
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4000.37
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

Ubuntu 18.04.5 LTS \n \l

              total        used        free      shared  buff/cache   available
Mem:            12G        785M        6.6G        1.2M        5.4G         11G
Swap:            0B          0B          0B

Mount Google Drive from Google Colab

In [5]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [6]:
! ls /content/drive
MyDrive  Shareddrives

Download source file from Google Drive or nw.tsuda.ac.jp

Basically, use gdown to download from Google Drive. Download from nw.tsuda.ac.jp only if the Google Drive specifications have changed and downloading from Google Drive no longer works.

In [7]:
# Download source file
nw_path = './nw'
! rm -rf {nw_path}
! mkdir -p {nw_path}

if True:   # from Google Drive
    url_model =  'https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3'
    ! (cd {nw_path}; gdown {url_model})
else:      # from nw.tsuda.ac.jp
    URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'
    url_model = f'{URL_NW}/models/VariationalAutoEncoder.py'
    ! wget -nd {url_model} -P {nw_path}
Downloading...
From: https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3
To: /content/nw/VariationalAutoEncoder.py
100% 18.7k/18.7k [00:00<00:00, 17.3MB/s]
In [8]:
! cat {nw_path}/VariationalAutoEncoder.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import os
import pickle
import datetime

class Sampling(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        mu, log_var = inputs
        epsilon = tf.keras.backend.random_normal(shape=tf.keras.backend.shape(mu), mean=0., stddev=1.)
        return mu + tf.keras.backend.exp(log_var / 2) * epsilon


class VAEModel(tf.keras.models.Model):
    def __init__(self, encoder, decoder, r_loss_factor, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.r_loss_factor = r_loss_factor


    @tf.function
    def loss_fn(self, x):
        z_mean, z_log_var, z = self.encoder(x)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            tf.square(x - reconstruction), axis=[1,2,3]
        ) * self.r_loss_factor
        kl_loss = tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
            axis = 1
        ) * (-0.5)
        total_loss = reconstruction_loss + kl_loss
        return total_loss, reconstruction_loss, kl_loss


    @tf.function
    def compute_loss_and_grads(self, x):
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self.loss_fn(x)
        grads = tape.gradient(total_loss, self.trainable_weights)
        return total_loss, reconstruction_loss, kl_loss, grads


    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        total_loss, reconstruction_loss, kl_loss, grads = self.compute_loss_and_grads(data)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": tf.math.reduce_mean(total_loss),
            "reconstruction_loss": tf.math.reduce_mean(reconstruction_loss),
            "kl_loss": tf.math.reduce_mean(kl_loss),
        }

    def call(self,inputs):
        _, _, z = self.encoder(inputs)
        return self.decoder(z)


class VariationalAutoEncoder():
    def __init__(self, 
                 input_dim,
                 encoder_conv_filters,
                 encoder_conv_kernel_size,
                 encoder_conv_strides,
                 decoder_conv_t_filters,
                 decoder_conv_t_kernel_size,
                 decoder_conv_t_strides,
                 z_dim,
                 r_loss_factor,   ### added
                 use_batch_norm = False,
                 use_dropout = False,
                 epoch = 0
                ):
        self.name = 'variational_autoencoder'
        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim
        self.r_loss_factor = r_loss_factor   ### added
            
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout

        self.epoch = epoch
            
        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)
            
        self._build()
 

    def _build(self):
        ### THE ENCODER
        encoder_input = tf.keras.layers.Input(shape=self.input_dim, name='encoder_input')
        x = encoder_input
        
        for i in range(self.n_layers_encoder):
            x = conv_layer = tf.keras.layers.Conv2D(
                filters = self.encoder_conv_filters[i],
                kernel_size = self.encoder_conv_kernel_size[i],
                strides = self.encoder_conv_strides[i],
                padding  = 'same',
                name = 'encoder_conv_' + str(i)
            )(x)

            if self.use_batch_norm:                                ### The order of layers is opposite to AutoEncoder
                x = tf.keras.layers.BatchNormalization()(x)        ###   AE: LeakyReLU -> BatchNorm
            x = tf.keras.layers.LeakyReLU()(x)                     ###   VAE: BatchNorm -> LeakyReLU
            
            if self.use_dropout:
                x = tf.keras.layers.Dropout(rate = 0.25)(x)
        
        shape_before_flattening = tf.keras.backend.int_shape(x)[1:]
        
        x = tf.keras.layers.Flatten()(x)
        
        self.mu = tf.keras.layers.Dense(self.z_dim, name='mu')(x)
        self.log_var = tf.keras.layers.Dense(self.z_dim, name='log_var')(x) 
        self.z = Sampling(name='encoder_output')([self.mu, self.log_var])
        
        self.encoder = tf.keras.models.Model(encoder_input, [self.mu, self.log_var, self.z], name='encoder')
        
        
        ### THE DECODER
        decoder_input = tf.keras.layers.Input(shape=(self.z_dim,), name='decoder_input')
        x = decoder_input
        x = tf.keras.layers.Dense(np.prod(shape_before_flattening))(x)
        x = tf.keras.layers.Reshape(shape_before_flattening)(x)
        
        for i in range(self.n_layers_decoder):
            x = conv_t_layer =   tf.keras.layers.Conv2DTranspose(
                filters = self.decoder_conv_t_filters[i],
                kernel_size = self.decoder_conv_t_kernel_size[i],
                strides = self.decoder_conv_t_strides[i],
                padding = 'same',
                name = 'decoder_conv_t_' + str(i)
            )(x)
            
            if i < self.n_layers_decoder - 1:
                if self.use_batch_norm:                           ### The order of layers is opposite to AutoEncoder
                    x = tf.keras.layers.BatchNormalization()(x)   ###     AE: LeakyReLU -> BatchNorm
                x = tf.keras.layers.LeakyReLU()(x)                ###      VAE: BatchNorm -> LeakyReLU                
                if self.use_dropout:
                    x = tf.keras.layers.Dropout(rate=0.25)(x)
            else:
                x = tf.keras.layers.Activation('sigmoid')(x)
       
        decoder_output = x
        self.decoder = tf.keras.models.Model(decoder_input, decoder_output, name='decoder')  ### added (name)
        
        ### THE FULL AUTOENCODER
        self.model = VAEModel(self.encoder, self.decoder, self.r_loss_factor)
        
        
    def save(self, folder):
        self.save_params(os.path.join(folder, 'params.pkl'))
        self.save_weights(folder)


    @staticmethod
    def load(folder, epoch=None):  # VariationalAutoEncoder.load(folder)
        params = VariationalAutoEncoder.load_params(os.path.join(folder, 'params.pkl'))
        VAE = VariationalAutoEncoder(*params)
        if epoch is None:
            VAE.load_weights(folder)
        else:
            VAE.load_weights(folder, epoch-1)
            VAE.epoch = epoch
        return VAE

        
    def save_params(self, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        with open(filepath, 'wb') as f:
            pickle.dump([
                self.input_dim,
                self.encoder_conv_filters,
                self.encoder_conv_kernel_size,
                self.encoder_conv_strides,
                self.decoder_conv_t_filters,
                self.decoder_conv_t_kernel_size,
                self.decoder_conv_t_strides,
                self.z_dim,
                self.r_loss_factor,
                self.use_batch_norm,
                self.use_dropout,
                self.epoch
            ], f)


    @staticmethod
    def load_params(filepath):
        with open(filepath, 'rb') as f:
            params = pickle.load(f)
        return params


    def save_weights(self, folder, epoch=None):
        if epoch is None:
            self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights.h5'))
            self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights.h5'))
        else:
            self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))
            self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))


    def save_model_weights(self, model, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        model.save_weights(filepath)


    def load_weights(self, folder, epoch=None):
        if epoch is None:
            self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights.h5'))
            self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights.h5'))
        else:
            self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))
            self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))


    def save_images(self, imgs, filepath):
        z_mean, z_log_var, z = self.encoder.predict(imgs)
        reconst_imgs = self.decoder.predict(z)
        txts = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z ]
        VariationalAutoEncoder.showImages(imgs, reconst_imgs, txts, 1.4, 1.4, 0.5, filepath)
      

    def compile(self, learning_rate):
        self.learning_rate = learning_rate
        optimizer = tf.keras.optimizers.Adam(lr=learning_rate)
        self.model.compile(optimizer=optimizer)     # CAUTION!!!: loss(y_true, y_pred) function is not specified.
        
        
    def train_with_fit(
            self,
            x_train,
            batch_size,
            epochs,
            run_folder='run/'
    ):
        history = self.model.fit(
            x_train,
            x_train,
            batch_size = batch_size,
            shuffle=True,
            initial_epoch = self.epoch,
            epochs = epochs
        )
        if (self.epoch < epochs):
            self.epoch = epochs

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)
        
        return history


    def train_generator_with_fit(
            self,
            data_flow,
            epochs,
            run_folder='run/'
    ):
        history = self.model.fit(
            data_flow,
            initial_epoch = self.epoch,
            epochs = epochs
        )
        if (self.epoch < epochs):
            self.epoch = epochs

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)
        
        return history


    def train_tf(
            self,
            x_train,
            batch_size = 32,
            epochs = 10,
            shuffle = False,
            run_folder = 'run/',
            optimizer = None,
            save_epoch_interval = 100,
            validation_data = None
    ):
        start_time = datetime.datetime.now()
        steps = x_train.shape[0] // batch_size

        total_losses = []
        reconstruction_losses = []
        kl_losses = []

        val_total_losses = []
        val_reconstruction_losses = []
        val_kl_losses = []

        for epoch in range(self.epoch, epochs):
            epoch_loss = 0
            indices = tf.range(x_train.shape[0], dtype=tf.int32)
            if shuffle:
                indices = tf.random.shuffle(indices)
            x_ = x_train[indices]

            step_total_losses = []
            step_reconstruction_losses = []
            step_kl_losses = []
            for step in range(steps):
                start = batch_size * step
                end = start + batch_size

                total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x_[start:end])
                optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
                
                step_total_losses.append(np.mean(total_loss))
                step_reconstruction_losses.append(np.mean(reconstruction_loss))
                step_kl_losses.append(np.mean(kl_loss))
            
            epoch_total_loss = np.mean(step_total_losses)
            epoch_reconstruction_loss = np.mean(step_reconstruction_losses)
            epoch_kl_loss = np.mean(step_kl_losses)

            total_losses.append(epoch_total_loss)
            reconstruction_losses.append(epoch_reconstruction_loss)
            kl_losses.append(epoch_kl_loss)

            val_str = ''
            if not validation_data is None:
                x_val = validation_data
                tl, rl, kl = self.model.loss_fn(x_val)
                val_tl = np.mean(tl)
                val_rl = np.mean(rl)
                val_kl = np.mean(kl)
                val_total_losses.append(val_tl)
                val_reconstruction_losses.append(val_rl)
                val_kl_losses.append(val_kl)
                val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '

            if (epoch+1) % save_epoch_interval == 0 and run_folder != None:
                self.save(run_folder)
                self.save_weights(run_folder, self.epoch)

            elapsed_time = datetime.datetime.now() - start_time
            print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')

            self.epoch += 1

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)

        dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }
        if not validation_data is None:
            dic['val_loss'] = val_total_losses
            dic['val_reconstruction_loss'] = val_reconstruction_losses
            dic['val_kl_loss'] = val_kl_losses

        return dic
            

    def train_tf_generator(
            self,
            data_flow,
            epochs = 10,
            run_folder = 'run/',
            optimizer = None,
            save_epoch_interval = 100,
            validation_data_flow = None
    ):
        start_time = datetime.datetime.now()
        steps = len(data_flow)

        total_losses = []
        reconstruction_losses = []
        kl_losses = []

        val_total_losses = []
        val_reconstruction_losses = []
        val_kl_losses = []

        for epoch in range(self.epoch, epochs):
            epoch_loss = 0

            step_total_losses = []
            step_reconstruction_losses = []
            step_kl_losses = []

            for step in range(steps):
                x, _ = next(data_flow)

                total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x)
                optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
                
                step_total_losses.append(np.mean(total_loss))
                step_reconstruction_losses.append(np.mean(reconstruction_loss))
                step_kl_losses.append(np.mean(kl_loss))
            
            epoch_total_loss = np.mean(step_total_losses)
            epoch_reconstruction_loss = np.mean(step_reconstruction_losses)
            epoch_kl_loss = np.mean(step_kl_losses)

            total_losses.append(epoch_total_loss)
            reconstruction_losses.append(epoch_reconstruction_loss)
            kl_losses.append(epoch_kl_loss)

            val_str = ''
            if not validation_data_flow is None:
                step_val_tl = []
                step_val_rl = []
                step_val_kl = []
                for i in range(len(validation_data_flow)):
                    x, _ = next(validation_data_flow)
                    tl, rl, kl = self.model.loss_fn(x)
                    step_val_tl.append(np.mean(tl))
                    step_val_rl.append(np.mean(rl))
                    step_val_kl.append(np.mean(kl))
                val_tl = np.mean(step_val_tl)
                val_rl = np.mean(step_val_rl)
                val_kl = np.mean(step_val_kl)
                val_total_losses.append(val_tl)
                val_reconstruction_losses.append(val_rl)
                val_kl_losses.append(val_kl)
                val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '

            if (epoch+1) % save_epoch_interval == 0 and run_folder != None:
                self.save(run_folder)
                self.save_weights(run_folder, self.epoch)

            elapsed_time = datetime.datetime.now() - start_time
            print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')

            self.epoch += 1

        if run_folder != None:
            self.save(run_folder)
            self.save_weights(run_folder, self.epoch-1)

        dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }
        if not validation_data_flow is None:
            dic['val_loss'] = val_total_losses
            dic['val_reconstruction_loss'] = val_reconstruction_losses
            dic['val_kl_loss'] = val_kl_losses

        return dic


    @staticmethod
    def showImages(imgs1, imgs2, txts, w, h, vskip=0.5, filepath=None):
        n = len(imgs1)
        fig, ax = plt.subplots(2, n, figsize=(w * n, (2+vskip) * h))
        for i in range(n):
            if n == 1:
                axis = ax[0]
            else:
                axis = ax[0][i]
            img = imgs1[i].squeeze()
            axis.imshow(img, cmap='gray_r')
            axis.axis('off')

            axis.text(0.5, -0.35, txts[i], fontsize=10, ha='center', transform=axis.transAxes)

            if n == 1:
                axis = ax[1]
            else:
                axis = ax[1][i]
            img2 = imgs2[i].squeeze()
            axis.imshow(img2, cmap='gray_r')
            axis.axis('off')

        if not filepath is None:
            dpath, fname = os.path.split(filepath)
            if dpath != '' and not os.path.exists(dpath):
                os.makedirs(dpath)
            fig.savefig(filepath, dpi=600)
            plt.close()
        else:
            plt.show()

    @staticmethod
    def plot_history(vals, labels):
        colors = ['red', 'blue', 'green', 'orange', 'black', 'pink']
        n = len(vals)
        fig, ax = plt.subplots(1, 1, figsize=(9,4))
        for i in range(n):
            ax.plot(vals[i], c=colors[i], label=labels[i])
        ax.legend(loc='upper right')
        ax.set_xlabel('epochs')
        # ax[0].set_ylabel('loss')
        
        plt.show()

Preparing MNIST dataset

In [9]:
%tensorflow_version 2.x

import tensorflow as tf
import numpy as np

print(tf.__version__)
2.7.0
In [10]:
# prepare data
(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = tf.keras.datasets.mnist.load_data()
print(x_train_raw.shape)
print(y_train_raw.shape)
print(x_test_raw.shape)
print(y_test_raw.shape)
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)
In [11]:
x_train = x_train_raw.reshape(x_train_raw.shape+(1,)).astype('float32') / 255.0
x_test = x_test_raw.reshape(x_test_raw.shape+(1,)).astype('float32') / 255.0
print(x_train.shape)
print(x_test.shape)
(60000, 28, 28, 1)
(10000, 28, 28, 1)
In [12]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

N = 10
selected_indices = np.random.choice(x_train_raw.shape[0], N)

fig, ax = plt.subplots(1, N, figsize=(2.8 * N, 2.8))
for i in range(N):
    ax[i].imshow(x_train_raw[selected_indices[i]],cmap='gray')
    ax[i].axis('off')

plt.show()

Definition of Neural Network Model

Use the VariationalAutoEncoder class downloaded above from Google Drive (or nw.tsuda.ac.jp).

In [13]:
from nw.VariationalAutoEncoder import VariationalAutoEncoder

vae = VariationalAutoEncoder(
    input_dim = (28, 28, 1),
    encoder_conv_filters = [32, 64, 64, 64],
    encoder_conv_kernel_size = [3, 3, 3, 3],
    encoder_conv_strides = [1, 2, 2, 1],
    decoder_conv_t_filters = [64, 64, 32, 1],
    decoder_conv_t_kernel_size = [3, 3, 3, 3],
    decoder_conv_t_strides = [1, 2, 2, 1],
    z_dim = 2,
    r_loss_factor = 1000    
)
In [14]:
vae.encoder.summary()
Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 encoder_input (InputLayer)     [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 encoder_conv_0 (Conv2D)        (None, 28, 28, 32)   320         ['encoder_input[0][0]']          
                                                                                                  
 leaky_re_lu (LeakyReLU)        (None, 28, 28, 32)   0           ['encoder_conv_0[0][0]']         
                                                                                                  
 encoder_conv_1 (Conv2D)        (None, 14, 14, 64)   18496       ['leaky_re_lu[0][0]']            
                                                                                                  
 leaky_re_lu_1 (LeakyReLU)      (None, 14, 14, 64)   0           ['encoder_conv_1[0][0]']         
                                                                                                  
 encoder_conv_2 (Conv2D)        (None, 7, 7, 64)     36928       ['leaky_re_lu_1[0][0]']          
                                                                                                  
 leaky_re_lu_2 (LeakyReLU)      (None, 7, 7, 64)     0           ['encoder_conv_2[0][0]']         
                                                                                                  
 encoder_conv_3 (Conv2D)        (None, 7, 7, 64)     36928       ['leaky_re_lu_2[0][0]']          
                                                                                                  
 leaky_re_lu_3 (LeakyReLU)      (None, 7, 7, 64)     0           ['encoder_conv_3[0][0]']         
                                                                                                  
 flatten (Flatten)              (None, 3136)         0           ['leaky_re_lu_3[0][0]']          
                                                                                                  
 mu (Dense)                     (None, 2)            6274        ['flatten[0][0]']                
                                                                                                  
 log_var (Dense)                (None, 2)            6274        ['flatten[0][0]']                
                                                                                                  
 encoder_output (Sampling)      (None, 2)            0           ['mu[0][0]',                     
                                                                  'log_var[0][0]']                
                                                                                                  
==================================================================================================
Total params: 105,220
Trainable params: 105,220
Non-trainable params: 0
__________________________________________________________________________________________________
In [15]:
vae.decoder.summary()
Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 decoder_input (InputLayer)  [(None, 2)]               0         
                                                                 
 dense (Dense)               (None, 3136)              9408      
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 decoder_conv_t_0 (Conv2DTra  (None, 7, 7, 64)         36928     
 nspose)                                                         
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 7, 7, 64)          0         
                                                                 
 decoder_conv_t_1 (Conv2DTra  (None, 14, 14, 64)       36928     
 nspose)                                                         
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 14, 14, 64)        0         
                                                                 
 decoder_conv_t_2 (Conv2DTra  (None, 28, 28, 32)       18464     
 nspose)                                                         
                                                                 
 leaky_re_lu_6 (LeakyReLU)   (None, 28, 28, 32)        0         
                                                                 
 decoder_conv_t_3 (Conv2DTra  (None, 28, 28, 1)        289       
 nspose)                                                         
                                                                 
 activation (Activation)     (None, 28, 28, 1)         0         
                                                                 
=================================================================
Total params: 102,017
Trainable params: 102,017
Non-trainable params: 0
_________________________________________________________________

Training

Train in 3 ways.

[Caution] Note that if you call the save_images() function during Training (2) and (3) below, encoder.predict() and decoder.predict() will run and execution will become very slow.

In [16]:
MAX_EPOCHS = 200
In [17]:
learning_rate = 0.0005

(1) Training: Use vae.model.fit()

Note that no loss function is specified in the call to vae.model.compile(). Since the loss cannot be computed simply from y_true and y_pred, the train_step() function of the VAEModel class, which is called from fit(), computes the loss and the gradients and performs the training. The self.optimizer of the VAEModel class referenced in train_step() is the optimizer given to the compile() function.
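
Concretely, from the loss_fn() of the VAEModel class shown above, the per-sample loss minimized by train_step() is

$\mbox{loss} = \mbox{r\_loss\_factor} \times \displaystyle\frac{1}{HWC}\sum_{\mbox{pixels}} (x - \hat{x})^2 \;-\; \frac{1}{2}\sum_{j=1}^{z\_dim} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)$

where $\hat{x}$ is the decoder's reconstruction, the first term averages the squared error over the image height $H$, width $W$ and channels $C$, and the second term is the KL divergence between $N(\mu, \sigma^2)$ and the standard normal distribution, computed from the encoder outputs mu and log_var.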

In [18]:
import os
save_path1 = '/content/drive/MyDrive/ColabRun/VAE01'
In [19]:
optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate)
vae.model.compile(optimizer=optimizer)
In [20]:
# First, try training for a small number of epochs

history = vae.train_with_fit(
    x_train,
    batch_size = 32,
    epochs = 3,
    run_folder = save_path1
)
Epoch 1/3
1875/1875 [==============================] - 12s 5ms/step - loss: 58.4073 - reconstruction_loss: 55.1287 - kl_loss: 3.2786
Epoch 2/3
1875/1875 [==============================] - 9s 5ms/step - loss: 51.5705 - reconstruction_loss: 47.5608 - kl_loss: 4.0097
Epoch 3/3
1875/1875 [==============================] - 9s 5ms/step - loss: 50.0829 - reconstruction_loss: 45.7901 - kl_loss: 4.2928
In [21]:
print(history.history)
{'loss': [55.75687789916992, 51.232696533203125, 51.035667419433594], 'reconstruction_loss': [51.87708282470703, 46.880558013916016, 46.8510856628418], 'kl_loss': [3.879793643951416, 4.352138042449951, 4.184581756591797]}
In [22]:
print(history.history.keys())
dict_keys(['loss', 'reconstruction_loss', 'kl_loss'])
In [23]:
#tmp = history.history['loss']

#print(len(tmp))
#print(len(tmp[0]))
In [24]:
loss1_1 = history.history['loss']
rloss1_1 = history.history['reconstruction_loss']
kloss1_1 = history.history['kl_loss']

Training in addition

Load the saved parameters and model weights, and try training further.

In [25]:
# Load the saved parameters and weights.

vae_work = VariationalAutoEncoder.load(save_path1)

# Display the epoch count of the model.

print(vae_work.epoch)
3
In [26]:
# Training in addition

vae_work.model.compile(optimizer)

history2 = vae_work.train_with_fit(
    x_train,
    batch_size = 32,
    epochs = MAX_EPOCHS,
    run_folder = save_path1
)
Epoch 4/200
1875/1875 [==============================] - 10s 5ms/step - loss: 49.8648 - reconstruction_loss: 45.5221 - kl_loss: 4.3427
Epoch 5/200
1875/1875 [==============================] - 10s 5ms/step - loss: 48.5465 - reconstruction_loss: 43.9828 - kl_loss: 4.5637
Epoch 6/200
1875/1875 [==============================] - 10s 5ms/step - loss: 48.0179 - reconstruction_loss: 43.3728 - kl_loss: 4.6451
Epoch 7/200
1875/1875 [==============================] - 10s 5ms/step - loss: 47.6538 - reconstruction_loss: 42.9425 - kl_loss: 4.7113
Epoch 8/200
1875/1875 [==============================] - 9s 5ms/step - loss: 47.3654 - reconstruction_loss: 42.6093 - kl_loss: 4.7561
Epoch 9/200
1875/1875 [==============================] - 9s 5ms/step - loss: 47.0985 - reconstruction_loss: 42.2911 - kl_loss: 4.8074
Epoch 10/200
1875/1875 [==============================] - 10s 5ms/step - loss: 46.9123 - reconstruction_loss: 42.0577 - kl_loss: 4.8546
Epoch 11/200
1875/1875 [==============================] - 9s 5ms/step - loss: 46.7336 - reconstruction_loss: 41.8435 - kl_loss: 4.8901
Epoch 12/200
1875/1875 [==============================] - 9s 5ms/step - loss: 46.5381 - reconstruction_loss: 41.6378 - kl_loss: 4.9003
Epoch 13/200
1875/1875 [==============================] - 10s 5ms/step - loss: 46.3920 - reconstruction_loss: 41.4584 - kl_loss: 4.9336
Epoch 14/200
1875/1875 [==============================] - 10s 5ms/step - loss: 46.2554 - reconstruction_loss: 41.2922 - kl_loss: 4.9632
Epoch 15/200
1875/1875 [==============================] - 9s 5ms/step - loss: 46.1605 - reconstruction_loss: 41.1937 - kl_loss: 4.9669
Epoch 16/200
1875/1875 [==============================] - 10s 5ms/step - loss: 46.0637 - reconstruction_loss: 41.0784 - kl_loss: 4.9853
Epoch 17/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.9436 - reconstruction_loss: 40.9368 - kl_loss: 5.0069
Epoch 18/200
1875/1875 [==============================] - 9s 5ms/step - loss: 45.8828 - reconstruction_loss: 40.8677 - kl_loss: 5.0151
Epoch 19/200
1875/1875 [==============================] - 9s 5ms/step - loss: 45.8291 - reconstruction_loss: 40.8016 - kl_loss: 5.0275
Epoch 20/200
1875/1875 [==============================] - 9s 5ms/step - loss: 45.7408 - reconstruction_loss: 40.6919 - kl_loss: 5.0488
Epoch 21/200
1875/1875 [==============================] - 9s 5ms/step - loss: 45.6722 - reconstruction_loss: 40.6060 - kl_loss: 5.0662
Epoch 22/200
1875/1875 [==============================] - 9s 5ms/step - loss: 45.6056 - reconstruction_loss: 40.5360 - kl_loss: 5.0696
Epoch 23/200
1875/1875 [==============================] - 9s 5ms/step - loss: 45.5273 - reconstruction_loss: 40.4394 - kl_loss: 5.0879
Epoch 24/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.4996 - reconstruction_loss: 40.4220 - kl_loss: 5.0776
Epoch 25/200
1875/1875 [==============================] - 9s 5ms/step - loss: 45.4448 - reconstruction_loss: 40.3568 - kl_loss: 5.0881
Epoch 26/200
1875/1875 [==============================] - 9s 5ms/step - loss: 45.3665 - reconstruction_loss: 40.2728 - kl_loss: 5.0937
Epoch 27/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.3241 - reconstruction_loss: 40.2183 - kl_loss: 5.1058
Epoch 28/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.2551 - reconstruction_loss: 40.1323 - kl_loss: 5.1228
Epoch 29/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.2791 - reconstruction_loss: 40.1603 - kl_loss: 5.1188
Epoch 30/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.2380 - reconstruction_loss: 40.1018 - kl_loss: 5.1362
Epoch 31/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.1768 - reconstruction_loss: 40.0453 - kl_loss: 5.1314
Epoch 32/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.1080 - reconstruction_loss: 39.9605 - kl_loss: 5.1475
Epoch 33/200
1875/1875 [==============================] - 9s 5ms/step - loss: 45.0962 - reconstruction_loss: 39.9452 - kl_loss: 5.1510
Epoch 34/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.0500 - reconstruction_loss: 39.8972 - kl_loss: 5.1528
Epoch 35/200
1875/1875 [==============================] - 10s 5ms/step - loss: 45.0309 - reconstruction_loss: 39.8673 - kl_loss: 5.1636
Epoch 36/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.9416 - reconstruction_loss: 39.7782 - kl_loss: 5.1634
Epoch 37/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.9534 - reconstruction_loss: 39.7739 - kl_loss: 5.1795
Epoch 38/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.9226 - reconstruction_loss: 39.7362 - kl_loss: 5.1864
Epoch 39/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.8763 - reconstruction_loss: 39.6899 - kl_loss: 5.1863
Epoch 40/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.8352 - reconstruction_loss: 39.6541 - kl_loss: 5.1811
Epoch 41/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.8205 - reconstruction_loss: 39.6282 - kl_loss: 5.1923
Epoch 42/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.8063 - reconstruction_loss: 39.6155 - kl_loss: 5.1908
Epoch 43/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.7581 - reconstruction_loss: 39.5525 - kl_loss: 5.2056
Epoch 44/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.7326 - reconstruction_loss: 39.5335 - kl_loss: 5.1991
Epoch 45/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.7131 - reconstruction_loss: 39.5036 - kl_loss: 5.2095
Epoch 46/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.7038 - reconstruction_loss: 39.4889 - kl_loss: 5.2149
Epoch 47/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.6709 - reconstruction_loss: 39.4595 - kl_loss: 5.2114
Epoch 48/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.6338 - reconstruction_loss: 39.4122 - kl_loss: 5.2215
Epoch 49/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.6278 - reconstruction_loss: 39.3961 - kl_loss: 5.2317
Epoch 50/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.5628 - reconstruction_loss: 39.3371 - kl_loss: 5.2257
Epoch 51/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.5658 - reconstruction_loss: 39.3284 - kl_loss: 5.2374
Epoch 52/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.5387 - reconstruction_loss: 39.3153 - kl_loss: 5.2234
Epoch 53/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.5424 - reconstruction_loss: 39.3081 - kl_loss: 5.2343
Epoch 54/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.5090 - reconstruction_loss: 39.2641 - kl_loss: 5.2449
Epoch 55/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.4785 - reconstruction_loss: 39.2307 - kl_loss: 5.2478
Epoch 56/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.4971 - reconstruction_loss: 39.2465 - kl_loss: 5.2506
Epoch 57/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.4685 - reconstruction_loss: 39.2069 - kl_loss: 5.2615
Epoch 58/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.4501 - reconstruction_loss: 39.1884 - kl_loss: 5.2617
Epoch 59/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.4471 - reconstruction_loss: 39.1896 - kl_loss: 5.2575
Epoch 60/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.3876 - reconstruction_loss: 39.1207 - kl_loss: 5.2668
Epoch 61/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.3903 - reconstruction_loss: 39.1179 - kl_loss: 5.2725
Epoch 62/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.3935 - reconstruction_loss: 39.1193 - kl_loss: 5.2742
Epoch 63/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.3373 - reconstruction_loss: 39.0673 - kl_loss: 5.2700
Epoch 64/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.3494 - reconstruction_loss: 39.0656 - kl_loss: 5.2839
Epoch 65/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.3091 - reconstruction_loss: 39.0228 - kl_loss: 5.2863
Epoch 66/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.3122 - reconstruction_loss: 39.0293 - kl_loss: 5.2829
Epoch 67/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.2802 - reconstruction_loss: 38.9922 - kl_loss: 5.2880
Epoch 68/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.2590 - reconstruction_loss: 38.9656 - kl_loss: 5.2934
Epoch 69/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.2646 - reconstruction_loss: 38.9654 - kl_loss: 5.2991
Epoch 70/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.2537 - reconstruction_loss: 38.9552 - kl_loss: 5.2984
Epoch 71/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.2393 - reconstruction_loss: 38.9393 - kl_loss: 5.3000
Epoch 72/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.2087 - reconstruction_loss: 38.8998 - kl_loss: 5.3089
Epoch 73/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.2166 - reconstruction_loss: 38.9082 - kl_loss: 5.3084
Epoch 74/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.1914 - reconstruction_loss: 38.8846 - kl_loss: 5.3068
Epoch 75/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.1735 - reconstruction_loss: 38.8549 - kl_loss: 5.3186
Epoch 76/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.1274 - reconstruction_loss: 38.8187 - kl_loss: 5.3087
Epoch 77/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.1572 - reconstruction_loss: 38.8292 - kl_loss: 5.3280
Epoch 78/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.1401 - reconstruction_loss: 38.8232 - kl_loss: 5.3169
Epoch 79/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.0833 - reconstruction_loss: 38.7594 - kl_loss: 5.3238
Epoch 80/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.1309 - reconstruction_loss: 38.8012 - kl_loss: 5.3298
Epoch 81/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.1047 - reconstruction_loss: 38.7807 - kl_loss: 5.3239
Epoch 82/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.1123 - reconstruction_loss: 38.7756 - kl_loss: 5.3367
Epoch 83/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.0822 - reconstruction_loss: 38.7454 - kl_loss: 5.3368
Epoch 84/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.0395 - reconstruction_loss: 38.7090 - kl_loss: 5.3305
Epoch 85/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.0405 - reconstruction_loss: 38.6956 - kl_loss: 5.3449
Epoch 86/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.0354 - reconstruction_loss: 38.6995 - kl_loss: 5.3359
Epoch 87/200
1875/1875 [==============================] - 10s 5ms/step - loss: 44.0608 - reconstruction_loss: 38.7157 - kl_loss: 5.3451
Epoch 88/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9911 - reconstruction_loss: 38.6527 - kl_loss: 5.3384
Epoch 89/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9644 - reconstruction_loss: 38.6286 - kl_loss: 5.3357
Epoch 90/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9787 - reconstruction_loss: 38.6390 - kl_loss: 5.3397
Epoch 91/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9799 - reconstruction_loss: 38.6322 - kl_loss: 5.3477
Epoch 92/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9351 - reconstruction_loss: 38.6002 - kl_loss: 5.3349
Epoch 93/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9710 - reconstruction_loss: 38.6108 - kl_loss: 5.3603
Epoch 94/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9590 - reconstruction_loss: 38.6000 - kl_loss: 5.3590
Epoch 95/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9179 - reconstruction_loss: 38.5682 - kl_loss: 5.3496
Epoch 96/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9087 - reconstruction_loss: 38.5479 - kl_loss: 5.3607
Epoch 97/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9272 - reconstruction_loss: 38.5629 - kl_loss: 5.3643
Epoch 98/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.9077 - reconstruction_loss: 38.5376 - kl_loss: 5.3701
Epoch 99/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8838 - reconstruction_loss: 38.5157 - kl_loss: 5.3681
Epoch 100/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8887 - reconstruction_loss: 38.5192 - kl_loss: 5.3695
Epoch 101/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8478 - reconstruction_loss: 38.4825 - kl_loss: 5.3653
Epoch 102/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8872 - reconstruction_loss: 38.5031 - kl_loss: 5.3841
Epoch 103/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8766 - reconstruction_loss: 38.5091 - kl_loss: 5.3675
Epoch 104/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8377 - reconstruction_loss: 38.4672 - kl_loss: 5.3705
Epoch 105/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8313 - reconstruction_loss: 38.4539 - kl_loss: 5.3774
Epoch 106/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8531 - reconstruction_loss: 38.4668 - kl_loss: 5.3864
Epoch 107/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7781 - reconstruction_loss: 38.4015 - kl_loss: 5.3765
Epoch 108/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8223 - reconstruction_loss: 38.4458 - kl_loss: 5.3765
Epoch 109/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7960 - reconstruction_loss: 38.4115 - kl_loss: 5.3844
Epoch 110/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8019 - reconstruction_loss: 38.4035 - kl_loss: 5.3984
Epoch 111/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8003 - reconstruction_loss: 38.4065 - kl_loss: 5.3938
Epoch 112/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.8260 - reconstruction_loss: 38.4250 - kl_loss: 5.4010
Epoch 113/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7902 - reconstruction_loss: 38.3985 - kl_loss: 5.3918
Epoch 114/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7565 - reconstruction_loss: 38.3619 - kl_loss: 5.3946
Epoch 115/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7997 - reconstruction_loss: 38.4051 - kl_loss: 5.3946
Epoch 116/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7543 - reconstruction_loss: 38.3572 - kl_loss: 5.3971
Epoch 117/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7634 - reconstruction_loss: 38.3714 - kl_loss: 5.3920
Epoch 118/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7563 - reconstruction_loss: 38.3530 - kl_loss: 5.4033
Epoch 119/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7339 - reconstruction_loss: 38.3282 - kl_loss: 5.4057
Epoch 120/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7077 - reconstruction_loss: 38.3175 - kl_loss: 5.3902
Epoch 121/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6943 - reconstruction_loss: 38.2902 - kl_loss: 5.4041
Epoch 122/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6897 - reconstruction_loss: 38.2903 - kl_loss: 5.3994
Epoch 123/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.7193 - reconstruction_loss: 38.3132 - kl_loss: 5.4060
Epoch 124/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6935 - reconstruction_loss: 38.2936 - kl_loss: 5.3999
Epoch 125/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6781 - reconstruction_loss: 38.2736 - kl_loss: 5.4045
Epoch 126/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6816 - reconstruction_loss: 38.2680 - kl_loss: 5.4135
Epoch 127/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6638 - reconstruction_loss: 38.2497 - kl_loss: 5.4141
Epoch 128/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6708 - reconstruction_loss: 38.2588 - kl_loss: 5.4120
Epoch 129/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6649 - reconstruction_loss: 38.2518 - kl_loss: 5.4131
Epoch 130/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6477 - reconstruction_loss: 38.2364 - kl_loss: 5.4113
Epoch 131/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6556 - reconstruction_loss: 38.2362 - kl_loss: 5.4194
Epoch 132/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6382 - reconstruction_loss: 38.2113 - kl_loss: 5.4269
Epoch 133/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6294 - reconstruction_loss: 38.2070 - kl_loss: 5.4224
Epoch 134/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6449 - reconstruction_loss: 38.2194 - kl_loss: 5.4255
Epoch 135/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6073 - reconstruction_loss: 38.1852 - kl_loss: 5.4221
Epoch 136/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6422 - reconstruction_loss: 38.2164 - kl_loss: 5.4259
Epoch 137/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5685 - reconstruction_loss: 38.1501 - kl_loss: 5.4184
Epoch 138/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.6314 - reconstruction_loss: 38.1939 - kl_loss: 5.4375
Epoch 139/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5616 - reconstruction_loss: 38.1262 - kl_loss: 5.4354
Epoch 140/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5970 - reconstruction_loss: 38.1678 - kl_loss: 5.4292
Epoch 141/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5744 - reconstruction_loss: 38.1391 - kl_loss: 5.4353
Epoch 142/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5831 - reconstruction_loss: 38.1496 - kl_loss: 5.4335
Epoch 143/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5867 - reconstruction_loss: 38.1375 - kl_loss: 5.4491
Epoch 144/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5403 - reconstruction_loss: 38.1161 - kl_loss: 5.4242
Epoch 145/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5645 - reconstruction_loss: 38.1276 - kl_loss: 5.4369
Epoch 146/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5185 - reconstruction_loss: 38.0682 - kl_loss: 5.4503
Epoch 147/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4965 - reconstruction_loss: 38.0613 - kl_loss: 5.4352
Epoch 148/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5425 - reconstruction_loss: 38.1036 - kl_loss: 5.4389
Epoch 149/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5263 - reconstruction_loss: 38.0834 - kl_loss: 5.4429
Epoch 150/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5132 - reconstruction_loss: 38.0750 - kl_loss: 5.4382
Epoch 151/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5202 - reconstruction_loss: 38.0786 - kl_loss: 5.4417
Epoch 152/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4944 - reconstruction_loss: 38.0514 - kl_loss: 5.4430
Epoch 153/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5075 - reconstruction_loss: 38.0543 - kl_loss: 5.4532
Epoch 154/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5126 - reconstruction_loss: 38.0520 - kl_loss: 5.4607
Epoch 155/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4627 - reconstruction_loss: 38.0185 - kl_loss: 5.4442
Epoch 156/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.5065 - reconstruction_loss: 38.0551 - kl_loss: 5.4514
Epoch 157/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4705 - reconstruction_loss: 38.0114 - kl_loss: 5.4592
Epoch 158/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4364 - reconstruction_loss: 37.9941 - kl_loss: 5.4423
Epoch 159/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4576 - reconstruction_loss: 37.9981 - kl_loss: 5.4595
Epoch 160/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4462 - reconstruction_loss: 37.9900 - kl_loss: 5.4562
Epoch 161/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4607 - reconstruction_loss: 37.9980 - kl_loss: 5.4627
Epoch 162/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4572 - reconstruction_loss: 37.9871 - kl_loss: 5.4701
Epoch 163/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3995 - reconstruction_loss: 37.9471 - kl_loss: 5.4524
Epoch 164/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4320 - reconstruction_loss: 37.9754 - kl_loss: 5.4566
Epoch 165/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4454 - reconstruction_loss: 37.9726 - kl_loss: 5.4728
Epoch 166/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4450 - reconstruction_loss: 37.9800 - kl_loss: 5.4650
Epoch 167/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4163 - reconstruction_loss: 37.9532 - kl_loss: 5.4632
Epoch 168/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4117 - reconstruction_loss: 37.9390 - kl_loss: 5.4726
Epoch 169/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4307 - reconstruction_loss: 37.9576 - kl_loss: 5.4731
Epoch 170/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4182 - reconstruction_loss: 37.9534 - kl_loss: 5.4649
Epoch 171/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3844 - reconstruction_loss: 37.9157 - kl_loss: 5.4687
Epoch 172/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4400 - reconstruction_loss: 37.9656 - kl_loss: 5.4744
Epoch 173/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4086 - reconstruction_loss: 37.9371 - kl_loss: 5.4715
Epoch 174/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.4037 - reconstruction_loss: 37.9290 - kl_loss: 5.4747
Epoch 175/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3655 - reconstruction_loss: 37.8892 - kl_loss: 5.4763
Epoch 176/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3934 - reconstruction_loss: 37.9160 - kl_loss: 5.4774
Epoch 177/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3601 - reconstruction_loss: 37.8833 - kl_loss: 5.4767
Epoch 178/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3450 - reconstruction_loss: 37.8623 - kl_loss: 5.4826
Epoch 179/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3546 - reconstruction_loss: 37.8703 - kl_loss: 5.4843
Epoch 180/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3659 - reconstruction_loss: 37.8991 - kl_loss: 5.4668
Epoch 181/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3438 - reconstruction_loss: 37.8470 - kl_loss: 5.4967
Epoch 182/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3575 - reconstruction_loss: 37.8747 - kl_loss: 5.4829
Epoch 183/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3240 - reconstruction_loss: 37.8332 - kl_loss: 5.4908
Epoch 184/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3059 - reconstruction_loss: 37.8316 - kl_loss: 5.4743
Epoch 185/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3094 - reconstruction_loss: 37.8254 - kl_loss: 5.4840
Epoch 186/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3536 - reconstruction_loss: 37.8561 - kl_loss: 5.4975
Epoch 187/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3179 - reconstruction_loss: 37.8290 - kl_loss: 5.4889
Epoch 188/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2968 - reconstruction_loss: 37.8050 - kl_loss: 5.4918
Epoch 189/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2858 - reconstruction_loss: 37.7874 - kl_loss: 5.4984
Epoch 190/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2893 - reconstruction_loss: 37.7938 - kl_loss: 5.4955
Epoch 191/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3107 - reconstruction_loss: 37.8106 - kl_loss: 5.5000
Epoch 192/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2927 - reconstruction_loss: 37.8019 - kl_loss: 5.4908
Epoch 193/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.3001 - reconstruction_loss: 37.7933 - kl_loss: 5.5068
Epoch 194/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2894 - reconstruction_loss: 37.7958 - kl_loss: 5.4936
Epoch 195/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2817 - reconstruction_loss: 37.7856 - kl_loss: 5.4961
Epoch 196/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2595 - reconstruction_loss: 37.7628 - kl_loss: 5.4967
Epoch 197/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2896 - reconstruction_loss: 37.7876 - kl_loss: 5.5020
Epoch 198/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2766 - reconstruction_loss: 37.7708 - kl_loss: 5.5057
Epoch 199/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2508 - reconstruction_loss: 37.7432 - kl_loss: 5.5076
Epoch 200/200
1875/1875 [==============================] - 10s 5ms/step - loss: 43.2551 - reconstruction_loss: 37.7514 - kl_loss: 5.5037
In [27]:
print(len(history2.history))
3
In [28]:
loss1_2 = history2.history['loss']
rloss1_2 = history2.history['reconstruction_loss']
kloss1_2 = history2.history['kl_loss']

loss1 = np.concatenate([loss1_1, loss1_2], axis=0)
rloss1 = np.concatenate([rloss1_1, rloss1_2], axis=0)
kloss1 = np.concatenate([kloss1_1, kloss1_2], axis=0)
In [29]:
VariationalAutoEncoder.plot_history([loss1, rloss1, kloss1], ['total_loss', 'reconstruct_loss', 'kl_loss'])

Validate Training results

Since vae.decoder() is declared with <code>@tf.function</code>, its return value is a Tensor and needs to be converted to a NumPy array.

学習結果を検証する

@tf.function 宣言のためvae.decoder() の返り値は Tensor になっているので、numpy の配列に変換する必要がある。

In [30]:
selected_indices = np.random.choice(range(len(x_test)), 10)
selected_images = x_test[selected_indices]
In [31]:
z_mean, z_log_var, z = vae_work.encoder(selected_images)
reconst_images = vae_work.decoder(z).numpy()  # Convert Tensor to numpy array.

txts = [f'{p[0]:.3f}, {p[1]:.3f}' for p in z ]
In [32]:
%matplotlib inline

VariationalAutoEncoder.showImages(selected_images, reconst_images, txts, 1.4, 1.4)

(2) Training with the tf.GradientTape() function

Instead of using fit(), compute the loss in your own train() function, obtain the gradients, and apply them to the variables.

The train_tf() function is sped up by declaring the compute_loss_and_grads() function, which computes the loss and the gradients, with <code>@tf.function</code>.

(2) tf.GradientTape() 関数を使った学習

fit() 関数を使わずに、自分で記述した train() 関数内で loss を計算し、gradients を求めて、変数に適用する。

train_tf() 関数では、lossとgradientsの計算を行う compute_loss_and_grads() 関数を <code>@tf.function</code> 宣言することで高速化を図っている。
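
For reference, the sketch below illustrates the idea behind compute_loss_and_grads(): compute the VAE loss under tf.GradientTape, take the gradients with respect to the trainable variables, and let the optimizer apply them. This is not the actual implementation in nw.VariationalAutoEncoder; the attribute names (encoder, decoder, r_loss_factor, trainable_variables) and the exact loss scaling are assumptions made only for illustration.

# Minimal sketch of a @tf.function training step (illustrative; assumes the model
# exposes encoder, decoder, r_loss_factor and trainable_variables as used above).
@tf.function
def compute_loss_and_grads_sketch(model, optimizer, x):
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = model.encoder(x)
        x_hat = model.decoder(z)
        # per-sample reconstruction loss, scaled by r_loss_factor (assumed scaling)
        r_loss = model.r_loss_factor * tf.reduce_mean(
            tf.square(x - x_hat), axis=[1, 2, 3]
        )
        # per-sample KL divergence between N(z_mean, exp(z_log_var)) and N(0, 1)
        kl_loss = -0.5 * tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1
        )
        total_loss = tf.reduce_mean(r_loss + kl_loss)
    # gradients of the total loss with respect to all trainable variables
    grads = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return total_loss, tf.reduce_mean(r_loss), tf.reduce_mean(kl_loss)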

In [33]:
save_path2 = '/content/drive/MyDrive/ColabRun/VAE02/'
In [34]:
from nw.VariationalAutoEncoder import VariationalAutoEncoder

vae2 = VariationalAutoEncoder(
    input_dim = (28, 28, 1),
    encoder_conv_filters = [32, 64, 64, 64],
    encoder_conv_kernel_size = [3, 3, 3, 3],
    encoder_conv_strides = [1, 2, 2, 1],
    decoder_conv_t_filters = [64, 64, 32, 1],
    decoder_conv_t_kernel_size = [3, 3, 3, 3],
    decoder_conv_t_strides = [1, 2, 2, 1],
    z_dim = 2,
    r_loss_factor = 1000 
)
In [35]:
optimizer2 = tf.keras.optimizers.Adam(learning_rate=learning_rate)
In [36]:
log2_1 = vae2.train_tf(
    x_train,
    batch_size = 32,
    epochs = 3,
    shuffle=True,
    run_folder = save_path2,
    optimizer = optimizer2,
    save_epoch_interval=50,
    validation_data=x_test
)
1/3 1875 loss: total 58.770 reconstruction 55.546 kl 3.224 val loss total 52.817 reconstruction 49.169 kl 3.648 0:00:18.577035
2/3 1875 loss: total 51.869 reconstruction 47.964 kl 3.905 val loss total 50.912 reconstruction 46.834 kl 4.078 0:00:35.127926
3/3 1875 loss: total 50.510 reconstruction 46.300 kl 4.210 val loss total 49.969 reconstruction 45.687 kl 4.282 0:00:51.720397
In [37]:
print(log2_1.keys())

loss2_1 = log2_1['loss']
rloss2_1 = log2_1['reconstruction_loss']
kloss2_1 = log2_1['kl_loss']
val_loss2_1 = log2_1['val_loss']
val_rloss2_1 = log2_1['val_reconstruction_loss']
val_kloss2_1 = log2_1['val_kl_loss']
dict_keys(['loss', 'reconstruction_loss', 'kl_loss', 'val_loss', 'val_reconstruction_loss', 'val_kl_loss'])
In [38]:
# Load the saved parameters and weights.
# 保存したパラメータと重みを読み込む

vae2_work = VariationalAutoEncoder.load(save_path2)
print(vae2_work.epoch)
3
In [39]:
# Train in addition
# 追加で training する。

log2_2 = vae2_work.train_tf(
    x_train,
    batch_size = 32,
    epochs = MAX_EPOCHS,
    shuffle=True,
    run_folder = save_path2,
    optimizer = optimizer2,
    save_epoch_interval=50,
    validation_data=x_test
)
4/200 1875 loss: total 50.162 reconstruction 45.848 kl 4.314 val loss total 49.152 reconstruction 44.580 kl 4.572 0:00:17.277455
5/200 1875 loss: total 48.757 reconstruction 44.223 kl 4.534 val loss total 48.424 reconstruction 43.738 kl 4.686 0:00:33.858980
6/200 1875 loss: total 48.176 reconstruction 43.538 kl 4.638 val loss total 48.115 reconstruction 43.766 kl 4.349 0:00:50.429768
7/200 1875 loss: total 47.757 reconstruction 43.043 kl 4.714 val loss total 47.459 reconstruction 42.677 kl 4.781 0:01:06.866417
8/200 1875 loss: total 47.441 reconstruction 42.676 kl 4.765 val loss total 47.894 reconstruction 42.884 kl 5.011 0:01:23.446778
9/200 1875 loss: total 47.179 reconstruction 42.367 kl 4.813 val loss total 47.031 reconstruction 42.219 kl 4.812 0:01:40.115804
10/200 1875 loss: total 46.942 reconstruction 42.089 kl 4.853 val loss total 46.790 reconstruction 42.095 kl 4.695 0:01:56.976094
11/200 1875 loss: total 46.754 reconstruction 41.864 kl 4.890 val loss total 46.903 reconstruction 41.996 kl 4.907 0:02:13.604735
12/200 1875 loss: total 46.624 reconstruction 41.701 kl 4.923 val loss total 46.661 reconstruction 41.812 kl 4.850 0:02:30.399161
13/200 1875 loss: total 46.499 reconstruction 41.559 kl 4.940 val loss total 46.625 reconstruction 41.618 kl 5.008 0:02:47.078383
14/200 1875 loss: total 46.342 reconstruction 41.383 kl 4.959 val loss total 46.390 reconstruction 41.315 kl 5.076 0:03:03.818319
15/200 1875 loss: total 46.197 reconstruction 41.232 kl 4.965 val loss total 46.465 reconstruction 41.601 kl 4.864 0:03:20.633008
16/200 1875 loss: total 46.073 reconstruction 41.083 kl 4.990 val loss total 46.003 reconstruction 40.929 kl 5.074 0:03:37.440654
17/200 1875 loss: total 46.008 reconstruction 40.996 kl 5.012 val loss total 46.937 reconstruction 41.735 kl 5.203 0:03:54.220229
18/200 1875 loss: total 45.943 reconstruction 40.904 kl 5.039 val loss total 45.829 reconstruction 40.918 kl 4.911 0:04:10.945407
19/200 1875 loss: total 45.833 reconstruction 40.779 kl 5.054 val loss total 46.202 reconstruction 41.148 kl 5.054 0:04:27.852774
20/200 1875 loss: total 45.762 reconstruction 40.709 kl 5.053 val loss total 46.060 reconstruction 40.935 kl 5.124 0:04:44.701887
21/200 1875 loss: total 45.679 reconstruction 40.611 kl 5.069 val loss total 46.077 reconstruction 40.876 kl 5.201 0:05:01.362951
22/200 1875 loss: total 45.562 reconstruction 40.497 kl 5.065 val loss total 46.065 reconstruction 40.979 kl 5.086 0:05:17.991716
23/200 1875 loss: total 45.533 reconstruction 40.431 kl 5.102 val loss total 45.621 reconstruction 40.551 kl 5.070 0:05:34.591369
24/200 1875 loss: total 45.472 reconstruction 40.364 kl 5.107 val loss total 45.815 reconstruction 40.818 kl 4.997 0:05:51.423349
25/200 1875 loss: total 45.377 reconstruction 40.264 kl 5.114 val loss total 45.679 reconstruction 40.624 kl 5.055 0:06:08.204300
26/200 1875 loss: total 45.338 reconstruction 40.203 kl 5.135 val loss total 45.798 reconstruction 40.550 kl 5.248 0:06:24.991477
27/200 1875 loss: total 45.294 reconstruction 40.157 kl 5.137 val loss total 45.682 reconstruction 40.614 kl 5.068 0:06:41.724944
28/200 1875 loss: total 45.259 reconstruction 40.107 kl 5.152 val loss total 45.423 reconstruction 40.295 kl 5.129 0:06:58.683955
29/200 1875 loss: total 45.229 reconstruction 40.077 kl 5.153 val loss total 45.632 reconstruction 40.402 kl 5.230 0:07:15.489602
30/200 1875 loss: total 45.148 reconstruction 39.993 kl 5.156 val loss total 45.333 reconstruction 40.367 kl 4.966 0:07:32.348106
31/200 1875 loss: total 45.064 reconstruction 39.895 kl 5.169 val loss total 45.687 reconstruction 40.597 kl 5.091 0:07:48.960279
32/200 1875 loss: total 45.056 reconstruction 39.872 kl 5.184 val loss total 45.327 reconstruction 40.105 kl 5.222 0:08:05.689345
33/200 1875 loss: total 45.022 reconstruction 39.833 kl 5.188 val loss total 45.747 reconstruction 40.727 kl 5.020 0:08:22.310960
34/200 1875 loss: total 44.959 reconstruction 39.769 kl 5.190 val loss total 45.642 reconstruction 40.386 kl 5.256 0:08:38.971414
35/200 1875 loss: total 44.928 reconstruction 39.731 kl 5.196 val loss total 45.334 reconstruction 40.153 kl 5.182 0:08:55.833276
36/200 1875 loss: total 44.873 reconstruction 39.652 kl 5.221 val loss total 45.349 reconstruction 40.222 kl 5.127 0:09:12.681772
37/200 1875 loss: total 44.908 reconstruction 39.690 kl 5.218 val loss total 45.325 reconstruction 40.216 kl 5.109 0:09:29.386281
38/200 1875 loss: total 44.829 reconstruction 39.609 kl 5.221 val loss total 45.356 reconstruction 40.176 kl 5.180 0:09:46.434155
39/200 1875 loss: total 44.807 reconstruction 39.582 kl 5.225 val loss total 45.302 reconstruction 40.268 kl 5.034 0:10:03.115526
40/200 1875 loss: total 44.733 reconstruction 39.506 kl 5.226 val loss total 45.273 reconstruction 40.016 kl 5.257 0:10:19.887071
41/200 1875 loss: total 44.740 reconstruction 39.492 kl 5.248 val loss total 45.072 reconstruction 39.703 kl 5.368 0:10:36.428644
42/200 1875 loss: total 44.681 reconstruction 39.445 kl 5.236 val loss total 45.187 reconstruction 40.148 kl 5.039 0:10:53.151162
43/200 1875 loss: total 44.670 reconstruction 39.417 kl 5.252 val loss total 45.718 reconstruction 40.213 kl 5.505 0:11:09.690726
44/200 1875 loss: total 44.639 reconstruction 39.384 kl 5.255 val loss total 44.865 reconstruction 39.639 kl 5.226 0:11:26.514591
45/200 1875 loss: total 44.595 reconstruction 39.344 kl 5.252 val loss total 45.539 reconstruction 40.319 kl 5.219 0:11:43.236509
46/200 1875 loss: total 44.584 reconstruction 39.322 kl 5.262 val loss total 44.822 reconstruction 39.555 kl 5.267 0:11:59.872629
47/200 1875 loss: total 44.535 reconstruction 39.261 kl 5.274 val loss total 44.937 reconstruction 39.704 kl 5.233 0:12:16.490185
48/200 1875 loss: total 44.543 reconstruction 39.269 kl 5.275 val loss total 45.009 reconstruction 39.745 kl 5.264 0:12:33.227318
49/200 1875 loss: total 44.497 reconstruction 39.224 kl 5.273 val loss total 44.963 reconstruction 39.664 kl 5.299 0:12:50.437077
50/200 1875 loss: total 44.467 reconstruction 39.183 kl 5.284 val loss total 44.915 reconstruction 39.677 kl 5.238 0:13:09.232365
51/200 1875 loss: total 44.436 reconstruction 39.159 kl 5.276 val loss total 44.993 reconstruction 39.640 kl 5.353 0:13:25.996844
52/200 1875 loss: total 44.412 reconstruction 39.120 kl 5.291 val loss total 44.940 reconstruction 39.749 kl 5.191 0:13:42.650136
53/200 1875 loss: total 44.401 reconstruction 39.101 kl 5.299 val loss total 45.460 reconstruction 40.218 kl 5.241 0:13:59.630201
54/200 1875 loss: total 44.378 reconstruction 39.074 kl 5.304 val loss total 44.902 reconstruction 39.627 kl 5.275 0:14:16.244289
55/200 1875 loss: total 44.329 reconstruction 39.027 kl 5.302 val loss total 44.698 reconstruction 39.412 kl 5.286 0:14:32.927279
56/200 1875 loss: total 44.328 reconstruction 39.021 kl 5.306 val loss total 44.774 reconstruction 39.458 kl 5.316 0:14:49.786937
57/200 1875 loss: total 44.311 reconstruction 39.006 kl 5.305 val loss total 45.276 reconstruction 40.009 kl 5.268 0:15:06.662111
58/200 1875 loss: total 44.268 reconstruction 38.944 kl 5.324 val loss total 44.543 reconstruction 39.329 kl 5.214 0:15:23.389418
59/200 1875 loss: total 44.256 reconstruction 38.936 kl 5.320 val loss total 44.842 reconstruction 39.525 kl 5.318 0:15:40.071372
60/200 1875 loss: total 44.224 reconstruction 38.911 kl 5.313 val loss total 45.060 reconstruction 39.890 kl 5.170 0:15:56.721918
61/200 1875 loss: total 44.225 reconstruction 38.897 kl 5.328 val loss total 44.975 reconstruction 39.801 kl 5.174 0:16:13.475832
62/200 1875 loss: total 44.174 reconstruction 38.834 kl 5.340 val loss total 44.882 reconstruction 39.777 kl 5.105 0:16:30.027535
63/200 1875 loss: total 44.199 reconstruction 38.860 kl 5.338 val loss total 44.548 reconstruction 39.299 kl 5.248 0:16:46.791431
64/200 1875 loss: total 44.167 reconstruction 38.819 kl 5.348 val loss total 44.697 reconstruction 39.299 kl 5.398 0:17:03.489724
65/200 1875 loss: total 44.148 reconstruction 38.797 kl 5.352 val loss total 44.866 reconstruction 39.424 kl 5.442 0:17:20.216007
66/200 1875 loss: total 44.130 reconstruction 38.771 kl 5.359 val loss total 44.923 reconstruction 39.618 kl 5.305 0:17:36.991212
67/200 1875 loss: total 44.112 reconstruction 38.754 kl 5.358 val loss total 44.940 reconstruction 39.698 kl 5.242 0:17:53.629803
68/200 1875 loss: total 44.103 reconstruction 38.747 kl 5.356 val loss total 44.848 reconstruction 39.356 kl 5.492 0:18:10.239933
69/200 1875 loss: total 44.067 reconstruction 38.705 kl 5.362 val loss total 44.600 reconstruction 39.351 kl 5.249 0:18:26.904989
70/200 1875 loss: total 44.053 reconstruction 38.696 kl 5.357 val loss total 44.939 reconstruction 39.447 kl 5.492 0:18:43.700989
71/200 1875 loss: total 44.060 reconstruction 38.694 kl 5.366 val loss total 44.609 reconstruction 39.252 kl 5.356 0:19:00.706969
72/200 1875 loss: total 44.042 reconstruction 38.678 kl 5.364 val loss total 44.712 reconstruction 39.334 kl 5.378 0:19:17.432070
73/200 1875 loss: total 43.982 reconstruction 38.624 kl 5.358 val loss total 44.594 reconstruction 39.255 kl 5.339 0:19:34.265796
74/200 1875 loss: total 44.001 reconstruction 38.636 kl 5.365 val loss total 44.569 reconstruction 39.131 kl 5.438 0:19:50.930350
75/200 1875 loss: total 43.987 reconstruction 38.626 kl 5.362 val loss total 44.847 reconstruction 39.518 kl 5.329 0:20:07.967674
76/200 1875 loss: total 43.970 reconstruction 38.596 kl 5.374 val loss total 44.979 reconstruction 39.569 kl 5.410 0:20:24.662119
77/200 1875 loss: total 43.932 reconstruction 38.559 kl 5.373 val loss total 44.545 reconstruction 39.293 kl 5.252 0:20:41.438850
78/200 1875 loss: total 43.957 reconstruction 38.573 kl 5.384 val loss total 44.696 reconstruction 39.378 kl 5.318 0:20:58.196325
79/200 1875 loss: total 43.928 reconstruction 38.545 kl 5.383 val loss total 44.910 reconstruction 39.715 kl 5.195 0:21:14.862080
80/200 1875 loss: total 43.938 reconstruction 38.566 kl 5.372 val loss total 44.910 reconstruction 39.632 kl 5.277 0:21:31.508570
81/200 1875 loss: total 43.895 reconstruction 38.515 kl 5.380 val loss total 44.606 reconstruction 39.231 kl 5.375 0:21:48.273321
82/200 1875 loss: total 43.883 reconstruction 38.490 kl 5.393 val loss total 44.540 reconstruction 39.222 kl 5.318 0:22:05.051810
83/200 1875 loss: total 43.852 reconstruction 38.447 kl 5.405 val loss total 44.589 reconstruction 39.331 kl 5.258 0:22:21.884867
84/200 1875 loss: total 43.878 reconstruction 38.487 kl 5.391 val loss total 44.797 reconstruction 39.560 kl 5.237 0:22:38.489883
85/200 1875 loss: total 43.858 reconstruction 38.464 kl 5.394 val loss total 44.503 reconstruction 38.991 kl 5.513 0:22:55.297158
86/200 1875 loss: total 43.847 reconstruction 38.444 kl 5.403 val loss total 45.060 reconstruction 39.555 kl 5.504 0:23:12.085322
87/200 1875 loss: total 43.818 reconstruction 38.421 kl 5.398 val loss total 44.583 reconstruction 39.236 kl 5.347 0:23:28.936725
88/200 1875 loss: total 43.811 reconstruction 38.395 kl 5.415 val loss total 44.657 reconstruction 39.300 kl 5.357 0:23:45.700052
89/200 1875 loss: total 43.807 reconstruction 38.400 kl 5.407 val loss total 44.839 reconstruction 39.446 kl 5.393 0:24:02.718024
90/200 1875 loss: total 43.798 reconstruction 38.382 kl 5.416 val loss total 44.587 reconstruction 39.109 kl 5.478 0:24:19.397046
91/200 1875 loss: total 43.758 reconstruction 38.338 kl 5.421 val loss total 44.654 reconstruction 39.212 kl 5.442 0:24:36.273507
92/200 1875 loss: total 43.762 reconstruction 38.345 kl 5.417 val loss total 44.715 reconstruction 39.360 kl 5.354 0:24:53.079947
93/200 1875 loss: total 43.736 reconstruction 38.316 kl 5.420 val loss total 44.789 reconstruction 39.331 kl 5.458 0:25:10.138744
94/200 1875 loss: total 43.741 reconstruction 38.318 kl 5.423 val loss total 44.553 reconstruction 39.193 kl 5.360 0:25:26.931501
95/200 1875 loss: total 43.734 reconstruction 38.323 kl 5.412 val loss total 44.566 reconstruction 39.231 kl 5.334 0:25:43.715210
96/200 1875 loss: total 43.698 reconstruction 38.285 kl 5.412 val loss total 44.548 reconstruction 39.247 kl 5.301 0:26:00.400480
97/200 1875 loss: total 43.703 reconstruction 38.270 kl 5.433 val loss total 44.535 reconstruction 38.953 kl 5.582 0:26:17.125616
98/200 1875 loss: total 43.697 reconstruction 38.281 kl 5.416 val loss total 44.915 reconstruction 39.559 kl 5.356 0:26:33.740765
99/200 1875 loss: total 43.664 reconstruction 38.232 kl 5.432 val loss total 44.571 reconstruction 39.142 kl 5.430 0:26:50.462840
100/200 1875 loss: total 43.717 reconstruction 38.277 kl 5.439 val loss total 44.529 reconstruction 39.154 kl 5.375 0:27:08.713689
101/200 1875 loss: total 43.698 reconstruction 38.262 kl 5.436 val loss total 44.647 reconstruction 39.233 kl 5.414 0:27:25.718150
102/200 1875 loss: total 43.660 reconstruction 38.225 kl 5.435 val loss total 44.906 reconstruction 39.429 kl 5.477 0:27:42.409837
103/200 1875 loss: total 43.668 reconstruction 38.217 kl 5.451 val loss total 45.036 reconstruction 39.725 kl 5.311 0:27:59.141782
104/200 1875 loss: total 43.621 reconstruction 38.194 kl 5.427 val loss total 44.354 reconstruction 38.913 kl 5.440 0:28:15.802837
105/200 1875 loss: total 43.633 reconstruction 38.199 kl 5.435 val loss total 44.493 reconstruction 39.069 kl 5.424 0:28:32.486469
106/200 1875 loss: total 43.632 reconstruction 38.187 kl 5.444 val loss total 44.539 reconstruction 39.130 kl 5.408 0:28:49.381544
107/200 1875 loss: total 43.619 reconstruction 38.170 kl 5.450 val loss total 44.381 reconstruction 38.927 kl 5.453 0:29:06.229204
108/200 1875 loss: total 43.597 reconstruction 38.161 kl 5.436 val loss total 44.457 reconstruction 39.116 kl 5.341 0:29:22.983122
109/200 1875 loss: total 43.634 reconstruction 38.190 kl 5.444 val loss total 44.476 reconstruction 39.043 kl 5.433 0:29:39.704490
110/200 1875 loss: total 43.547 reconstruction 38.111 kl 5.435 val loss total 44.423 reconstruction 38.924 kl 5.500 0:29:56.461920
111/200 1875 loss: total 43.598 reconstruction 38.147 kl 5.452 val loss total 44.405 reconstruction 39.003 kl 5.402 0:30:13.250631
112/200 1875 loss: total 43.567 reconstruction 38.128 kl 5.439 val loss total 44.861 reconstruction 39.518 kl 5.342 0:30:30.170810
113/200 1875 loss: total 43.531 reconstruction 38.089 kl 5.442 val loss total 44.599 reconstruction 39.228 kl 5.371 0:30:46.987956
114/200 1875 loss: total 43.562 reconstruction 38.102 kl 5.460 val loss total 44.872 reconstruction 39.365 kl 5.507 0:31:03.791419
115/200 1875 loss: total 43.552 reconstruction 38.096 kl 5.456 val loss total 44.457 reconstruction 38.982 kl 5.476 0:31:20.504139
116/200 1875 loss: total 43.574 reconstruction 38.116 kl 5.458 val loss total 44.843 reconstruction 39.337 kl 5.506 0:31:37.326019
117/200 1875 loss: total 43.549 reconstruction 38.092 kl 5.456 val loss total 44.368 reconstruction 38.947 kl 5.421 0:31:54.163002
118/200 1875 loss: total 43.541 reconstruction 38.086 kl 5.454 val loss total 44.516 reconstruction 39.079 kl 5.437 0:32:10.976940
119/200 1875 loss: total 43.506 reconstruction 38.044 kl 5.462 val loss total 44.526 reconstruction 39.054 kl 5.472 0:32:27.707200
120/200 1875 loss: total 43.530 reconstruction 38.061 kl 5.469 val loss total 44.451 reconstruction 39.100 kl 5.351 0:32:44.453930
121/200 1875 loss: total 43.511 reconstruction 38.056 kl 5.456 val loss total 44.507 reconstruction 39.091 kl 5.416 0:33:01.145372
122/200 1875 loss: total 43.515 reconstruction 38.042 kl 5.473 val loss total 44.273 reconstruction 38.865 kl 5.408 0:33:17.951876
123/200 1875 loss: total 43.491 reconstruction 38.031 kl 5.460 val loss total 44.403 reconstruction 38.976 kl 5.427 0:33:34.690088
124/200 1875 loss: total 43.510 reconstruction 38.031 kl 5.479 val loss total 44.164 reconstruction 38.672 kl 5.492 0:33:51.701289
125/200 1875 loss: total 43.448 reconstruction 37.982 kl 5.466 val loss total 44.394 reconstruction 38.879 kl 5.515 0:34:08.485789
126/200 1875 loss: total 43.483 reconstruction 38.014 kl 5.469 val loss total 44.424 reconstruction 38.998 kl 5.425 0:34:25.384367
127/200 1875 loss: total 43.448 reconstruction 37.961 kl 5.487 val loss total 44.410 reconstruction 39.050 kl 5.360 0:34:42.121120
128/200 1875 loss: total 43.471 reconstruction 37.989 kl 5.482 val loss total 44.356 reconstruction 38.798 kl 5.558 0:34:59.016617
129/200 1875 loss: total 43.445 reconstruction 37.963 kl 5.482 val loss total 44.448 reconstruction 39.027 kl 5.420 0:35:15.643733
130/200 1875 loss: total 43.450 reconstruction 37.971 kl 5.479 val loss total 44.686 reconstruction 39.340 kl 5.346 0:35:32.643387
131/200 1875 loss: total 43.439 reconstruction 37.964 kl 5.476 val loss total 44.323 reconstruction 38.969 kl 5.355 0:35:49.531323
132/200 1875 loss: total 43.422 reconstruction 37.937 kl 5.486 val loss total 44.458 reconstruction 38.916 kl 5.543 0:36:06.274550
133/200 1875 loss: total 43.433 reconstruction 37.955 kl 5.478 val loss total 44.712 reconstruction 39.260 kl 5.452 0:36:23.000081
134/200 1875 loss: total 43.398 reconstruction 37.916 kl 5.482 val loss total 44.731 reconstruction 39.364 kl 5.367 0:36:39.782398
135/200 1875 loss: total 43.401 reconstruction 37.913 kl 5.488 val loss total 44.501 reconstruction 39.107 kl 5.394 0:36:56.623039
136/200 1875 loss: total 43.420 reconstruction 37.930 kl 5.490 val loss total 44.450 reconstruction 39.059 kl 5.391 0:37:13.470758
137/200 1875 loss: total 43.378 reconstruction 37.892 kl 5.486 val loss total 44.338 reconstruction 38.941 kl 5.397 0:37:30.314993
138/200 1875 loss: total 43.404 reconstruction 37.908 kl 5.496 val loss total 44.687 reconstruction 39.322 kl 5.365 0:37:47.283535
139/200 1875 loss: total 43.369 reconstruction 37.891 kl 5.478 val loss total 44.419 reconstruction 38.946 kl 5.473 0:38:04.040351
140/200 1875 loss: total 43.353 reconstruction 37.854 kl 5.498 val loss total 44.724 reconstruction 39.315 kl 5.409 0:38:20.965275
141/200 1875 loss: total 43.339 reconstruction 37.844 kl 5.494 val loss total 44.756 reconstruction 39.376 kl 5.380 0:38:37.634737
142/200 1875 loss: total 43.363 reconstruction 37.858 kl 5.505 val loss total 44.629 reconstruction 39.214 kl 5.415 0:38:54.808788
143/200 1875 loss: total 43.333 reconstruction 37.835 kl 5.498 val loss total 44.383 reconstruction 38.959 kl 5.423 0:39:11.579696
144/200 1875 loss: total 43.363 reconstruction 37.860 kl 5.503 val loss total 44.387 reconstruction 38.886 kl 5.501 0:39:28.543206
145/200 1875 loss: total 43.345 reconstruction 37.854 kl 5.491 val loss total 44.623 reconstruction 39.179 kl 5.444 0:39:45.339762
146/200 1875 loss: total 43.347 reconstruction 37.857 kl 5.491 val loss total 44.545 reconstruction 39.038 kl 5.507 0:40:02.435161
147/200 1875 loss: total 43.344 reconstruction 37.846 kl 5.498 val loss total 44.422 reconstruction 38.846 kl 5.576 0:40:19.238820
148/200 1875 loss: total 43.311 reconstruction 37.794 kl 5.517 val loss total 44.926 reconstruction 39.431 kl 5.495 0:40:36.204852
149/200 1875 loss: total 43.294 reconstruction 37.802 kl 5.493 val loss total 44.266 reconstruction 38.886 kl 5.379 0:40:53.304109
150/200 1875 loss: total 43.289 reconstruction 37.787 kl 5.502 val loss total 44.467 reconstruction 39.091 kl 5.376 0:41:12.059408
151/200 1875 loss: total 43.318 reconstruction 37.828 kl 5.490 val loss total 44.304 reconstruction 38.797 kl 5.507 0:41:28.942447
152/200 1875 loss: total 43.319 reconstruction 37.804 kl 5.515 val loss total 44.797 reconstruction 39.314 kl 5.483 0:41:45.770619
153/200 1875 loss: total 43.280 reconstruction 37.776 kl 5.504 val loss total 44.282 reconstruction 38.878 kl 5.403 0:42:02.625198
154/200 1875 loss: total 43.288 reconstruction 37.778 kl 5.510 val loss total 44.405 reconstruction 38.856 kl 5.549 0:42:19.556715
155/200 1875 loss: total 43.279 reconstruction 37.759 kl 5.520 val loss total 44.326 reconstruction 38.885 kl 5.441 0:42:36.327122
156/200 1875 loss: total 43.279 reconstruction 37.759 kl 5.519 val loss total 44.640 reconstruction 39.144 kl 5.497 0:42:53.148910
157/200 1875 loss: total 43.291 reconstruction 37.780 kl 5.511 val loss total 44.781 reconstruction 39.342 kl 5.439 0:43:09.944585
158/200 1875 loss: total 43.252 reconstruction 37.731 kl 5.521 val loss total 44.255 reconstruction 38.704 kl 5.551 0:43:26.725540
159/200 1875 loss: total 43.258 reconstruction 37.755 kl 5.504 val loss total 44.274 reconstruction 38.728 kl 5.547 0:43:43.660450
160/200 1875 loss: total 43.246 reconstruction 37.732 kl 5.514 val loss total 44.526 reconstruction 39.149 kl 5.377 0:44:00.569521
161/200 1875 loss: total 43.256 reconstruction 37.730 kl 5.526 val loss total 44.721 reconstruction 39.360 kl 5.361 0:44:17.447056
162/200 1875 loss: total 43.247 reconstruction 37.739 kl 5.508 val loss total 44.347 reconstruction 38.975 kl 5.372 0:44:34.201379
163/200 1875 loss: total 43.264 reconstruction 37.749 kl 5.515 val loss total 44.633 reconstruction 39.102 kl 5.531 0:44:51.069020
164/200 1875 loss: total 43.232 reconstruction 37.700 kl 5.533 val loss total 44.235 reconstruction 38.727 kl 5.508 0:45:07.936432
165/200 1875 loss: total 43.222 reconstruction 37.695 kl 5.527 val loss total 44.676 reconstruction 39.089 kl 5.586 0:45:24.745173
166/200 1875 loss: total 43.237 reconstruction 37.694 kl 5.543 val loss total 44.603 reconstruction 39.044 kl 5.559 0:45:41.555163
167/200 1875 loss: total 43.201 reconstruction 37.679 kl 5.522 val loss total 45.117 reconstruction 39.421 kl 5.696 0:45:58.754834
168/200 1875 loss: total 43.179 reconstruction 37.651 kl 5.528 val loss total 44.400 reconstruction 38.888 kl 5.513 0:46:15.605216
169/200 1875 loss: total 43.207 reconstruction 37.675 kl 5.532 val loss total 44.796 reconstruction 39.286 kl 5.510 0:46:32.352851
170/200 1875 loss: total 43.193 reconstruction 37.669 kl 5.525 val loss total 44.349 reconstruction 38.851 kl 5.497 0:46:49.054212
171/200 1875 loss: total 43.164 reconstruction 37.631 kl 5.533 val loss total 44.417 reconstruction 38.912 kl 5.504 0:47:06.073833
172/200 1875 loss: total 43.193 reconstruction 37.663 kl 5.530 val loss total 44.523 reconstruction 39.029 kl 5.494 0:47:23.184783
173/200 1875 loss: total 43.198 reconstruction 37.671 kl 5.528 val loss total 44.748 reconstruction 39.097 kl 5.651 0:47:40.272333
174/200 1875 loss: total 43.150 reconstruction 37.617 kl 5.533 val loss total 44.320 reconstruction 38.973 kl 5.347 0:47:57.381127
175/200 1875 loss: total 43.152 reconstruction 37.614 kl 5.538 val loss total 44.188 reconstruction 38.637 kl 5.551 0:48:14.305486
176/200 1875 loss: total 43.169 reconstruction 37.636 kl 5.533 val loss total 44.226 reconstruction 38.675 kl 5.550 0:48:30.979480
177/200 1875 loss: total 43.198 reconstruction 37.672 kl 5.527 val loss total 44.064 reconstruction 38.547 kl 5.517 0:48:47.847159
178/200 1875 loss: total 43.153 reconstruction 37.620 kl 5.533 val loss total 44.566 reconstruction 39.000 kl 5.566 0:49:04.541262
179/200 1875 loss: total 43.155 reconstruction 37.620 kl 5.535 val loss total 44.416 reconstruction 38.958 kl 5.458 0:49:21.438403
180/200 1875 loss: total 43.150 reconstruction 37.615 kl 5.536 val loss total 44.406 reconstruction 38.900 kl 5.506 0:49:38.162147
181/200 1875 loss: total 43.160 reconstruction 37.620 kl 5.540 val loss total 44.964 reconstruction 39.444 kl 5.520 0:49:55.020021
182/200 1875 loss: total 43.139 reconstruction 37.600 kl 5.539 val loss total 44.641 reconstruction 39.146 kl 5.495 0:50:11.784347
183/200 1875 loss: total 43.169 reconstruction 37.631 kl 5.538 val loss total 44.338 reconstruction 38.913 kl 5.425 0:50:28.659054
184/200 1875 loss: total 43.115 reconstruction 37.565 kl 5.551 val loss total 44.322 reconstruction 38.887 kl 5.435 0:50:45.484843
185/200 1875 loss: total 43.120 reconstruction 37.580 kl 5.539 val loss total 44.345 reconstruction 38.813 kl 5.532 0:51:02.358411
186/200 1875 loss: total 43.093 reconstruction 37.539 kl 5.554 val loss total 44.605 reconstruction 38.971 kl 5.634 0:51:19.055859
187/200 1875 loss: total 43.127 reconstruction 37.591 kl 5.535 val loss total 44.187 reconstruction 38.602 kl 5.584 0:51:35.771913
188/200 1875 loss: total 43.126 reconstruction 37.586 kl 5.540 val loss total 44.505 reconstruction 38.993 kl 5.512 0:51:52.499624
189/200 1875 loss: total 43.105 reconstruction 37.561 kl 5.544 val loss total 44.451 reconstruction 39.028 kl 5.424 0:52:09.349090
190/200 1875 loss: total 43.107 reconstruction 37.575 kl 5.533 val loss total 44.518 reconstruction 38.827 kl 5.691 0:52:26.095645
191/200 1875 loss: total 43.099 reconstruction 37.544 kl 5.556 val loss total 44.444 reconstruction 38.900 kl 5.543 0:52:42.889732
192/200 1875 loss: total 43.106 reconstruction 37.557 kl 5.549 val loss total 44.614 reconstruction 39.069 kl 5.545 0:52:59.796363
193/200 1875 loss: total 43.087 reconstruction 37.535 kl 5.551 val loss total 44.505 reconstruction 38.968 kl 5.538 0:53:16.926831
194/200 1875 loss: total 43.075 reconstruction 37.527 kl 5.548 val loss total 44.447 reconstruction 38.995 kl 5.452 0:53:33.688983
195/200 1875 loss: total 43.087 reconstruction 37.533 kl 5.554 val loss total 44.354 reconstruction 38.872 kl 5.482 0:53:50.606192
196/200 1875 loss: total 43.090 reconstruction 37.535 kl 5.555 val loss total 44.326 reconstruction 38.884 kl 5.442 0:54:07.173908
197/200 1875 loss: total 43.068 reconstruction 37.527 kl 5.542 val loss total 44.311 reconstruction 38.868 kl 5.443 0:54:23.983531
198/200 1875 loss: total 43.096 reconstruction 37.554 kl 5.542 val loss total 44.420 reconstruction 38.883 kl 5.537 0:54:40.752294
199/200 1875 loss: total 43.055 reconstruction 37.498 kl 5.557 val loss total 44.232 reconstruction 38.625 kl 5.607 0:54:57.443065
200/200 1875 loss: total 43.043 reconstruction 37.498 kl 5.546 val loss total 44.157 reconstruction 38.579 kl 5.578 0:55:15.500341
In [40]:
loss2_2 = log2_2['loss']
rloss2_2 = log2_2['reconstruction_loss']
kloss2_2 = log2_2['kl_loss']
val_loss2_2 = log2_2['val_loss']
val_rloss2_2 = log2_2['val_reconstruction_loss']
val_kloss2_2 = log2_2['val_kl_loss']
In [41]:
loss2 = np.concatenate([loss2_1, loss2_2], axis=0)
rloss2 = np.concatenate([rloss2_1, rloss2_2], axis=0)
kloss2 = np.concatenate([kloss2_1, kloss2_2], axis=0)

val_loss2 = np.concatenate([val_loss2_1, val_loss2_2], axis=0)
val_rloss2 = np.concatenate([val_rloss2_1, val_rloss2_2], axis=0)
val_kloss2 = np.concatenate([val_kloss2_1, val_kloss2_2], axis=0)
In [42]:
VariationalAutoEncoder.plot_history(
    [loss2, val_loss2], 
    ['total_loss', 'val_total_loss']
)
In [43]:
VariationalAutoEncoder.plot_history(
    [rloss2, val_rloss2], 
    ['reconstruction_loss', 'val_reconstruction_loss']
)
In [44]:
VariationalAutoEncoder.plot_history(
    [kloss2, val_kloss2], 
    ['kl_loss', 'val_kl_loss']
)
In [45]:
z_mean2, z_log_var2, z2 = vae2_work.encoder(selected_images)
reconst_images2 = vae2_work.decoder(z2).numpy()  # decoder() returns a Tensor because of the @tf.function declaration, so convert it to a numpy array.

txts2 = [f'{p[0]:.3f}, {p[1]:.3f}' for p in z2 ]
In [46]:
%matplotlib inline

VariationalAutoEncoder.showImages(selected_images, reconst_images2, txts2, 1.4, 1.4)

(3) Training with the tf.GradientTape() function and learning rate decay

Compute the loss and gradients with the tf.GradientTape() function, and apply the gradients to the variables. In addition, apply learning rate decay in the optimizer.

(3) tf.GradientTape() 関数と学習率減衰を使った学習

tf.GradientTape() 関数を使って loss と gradients を計算して、gradients を変数に適用する。 さらに、optimizer において Learning rate decay を行う。

In [47]:
save_path3 = '/content/drive/MyDrive/ColabRun/VAE03/'
In [48]:
from nw.VariationalAutoEncoder import VariationalAutoEncoder

vae3 = VariationalAutoEncoder(
    input_dim = (28, 28, 1),
    encoder_conv_filters = [32, 64, 64, 64],
    encoder_conv_kernel_size = [3, 3, 3, 3],
    encoder_conv_strides = [1, 2, 2, 1],
    decoder_conv_t_filters = [64, 64, 32, 1],
    decoder_conv_t_kernel_size = [3, 3, 3, 3],
    decoder_conv_t_strides = [1, 2, 2, 1],
    z_dim = 2,
    r_loss_factor = 1000 
)
In [49]:
# learning rate = initial_learning_rate * decay_rate ^ (step / decay_steps)
# (staircase=False by default, so the exponent varies continuously with the step count)

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate = learning_rate,
    decay_steps = 1000,
    decay_rate=0.96
)

optimizer3 = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
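
As an illustrative check (not a cell from the original notebook), the schedule object can be called directly with a step count to see the decayed learning rate that the optimizer will use at that point in training:

# Evaluate the exponential-decay schedule at a few step counts (illustrative check).
# With decay_steps=1000 and decay_rate=0.96, the rate shrinks by a factor of 0.96
# per 1000 optimizer steps; here one epoch is 1875 steps (60000 images / batch size 32).
for step in [0, 1000, 1875 * 10, 1875 * 100]:
    print(step, float(lr_schedule(step)))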
In [50]:
log3_1 = vae3.train_tf(
    x_train,
    batch_size = 32,
    epochs = 3,
    shuffle=True,
    run_folder = save_path3,
    optimizer = optimizer3,
    save_epoch_interval=50,
    validation_data=x_test
)
1/3 1875 loss: total 58.493 reconstruction 55.239 kl 3.254 val loss total 53.440 reconstruction 49.923 kl 3.517 0:00:17.739827
2/3 1875 loss: total 51.609 reconstruction 47.612 kl 3.997 val loss total 50.649 reconstruction 46.397 kl 4.252 0:00:34.873874
3/3 1875 loss: total 50.013 reconstruction 45.694 kl 4.319 val loss total 49.614 reconstruction 44.967 kl 4.647 0:00:52.016673
In [51]:
print(log3_1.keys())

loss3_1 = log3_1['loss']
rloss3_1 = log3_1['reconstruction_loss']
kloss3_1 = log3_1['kl_loss']
val_loss3_1 = log3_1['val_loss']
val_rloss3_1 = log3_1['val_reconstruction_loss']
val_kloss3_1 = log3_1['val_kl_loss']
dict_keys(['loss', 'reconstruction_loss', 'kl_loss', 'val_loss', 'val_reconstruction_loss', 'val_kl_loss'])
In [52]:
# Load the saved parameters and weights.
# 保存したパラメータと重みを読み込む

vae3_work = VariationalAutoEncoder.load(save_path3)
print(vae3_work.epoch)
3
In [53]:
# Train in addition
# 追加で training する。
log3_2 = vae3_work.train_tf(
    x_train,
    batch_size = 32,
    epochs = MAX_EPOCHS,
    shuffle=True,
    run_folder = save_path3,
    optimizer = optimizer3,
    save_epoch_interval=50,
    validation_data=x_test
)
4/200 1875 loss: total 49.658 reconstruction 45.263 kl 4.395 val loss total 48.749 reconstruction 44.186 kl 4.563 0:00:17.616743
5/200 1875 loss: total 48.326 reconstruction 43.730 kl 4.596 val loss total 47.960 reconstruction 43.345 kl 4.614 0:00:34.516737
6/200 1875 loss: total 47.774 reconstruction 43.062 kl 4.713 val loss total 47.793 reconstruction 42.888 kl 4.905 0:00:51.485786
7/200 1875 loss: total 47.366 reconstruction 42.606 kl 4.760 val loss total 47.576 reconstruction 42.752 kl 4.824 0:01:08.380952
8/200 1875 loss: total 47.046 reconstruction 42.229 kl 4.817 val loss total 46.938 reconstruction 42.214 kl 4.724 0:01:25.407395
9/200 1875 loss: total 46.772 reconstruction 41.907 kl 4.865 val loss total 46.838 reconstruction 41.859 kl 4.979 0:01:42.374036
10/200 1875 loss: total 46.537 reconstruction 41.651 kl 4.886 val loss total 46.596 reconstruction 41.741 kl 4.855 0:01:59.453270
11/200 1875 loss: total 46.334 reconstruction 41.412 kl 4.922 val loss total 46.542 reconstruction 41.646 kl 4.896 0:02:16.500478
12/200 1875 loss: total 46.159 reconstruction 41.213 kl 4.946 val loss total 46.120 reconstruction 40.986 kl 5.134 0:02:33.443356
13/200 1875 loss: total 45.994 reconstruction 41.017 kl 4.977 val loss total 46.211 reconstruction 41.229 kl 4.982 0:02:50.445165
14/200 1875 loss: total 45.854 reconstruction 40.845 kl 5.010 val loss total 46.023 reconstruction 40.963 kl 5.060 0:03:07.401844
15/200 1875 loss: total 45.701 reconstruction 40.676 kl 5.025 val loss total 45.940 reconstruction 40.938 kl 5.002 0:03:24.243303
16/200 1875 loss: total 45.573 reconstruction 40.530 kl 5.043 val loss total 45.838 reconstruction 40.673 kl 5.166 0:03:41.205999
17/200 1875 loss: total 45.465 reconstruction 40.410 kl 5.055 val loss total 45.770 reconstruction 40.754 kl 5.017 0:03:58.180931
18/200 1875 loss: total 45.338 reconstruction 40.275 kl 5.063 val loss total 45.646 reconstruction 40.577 kl 5.069 0:04:15.244548
19/200 1875 loss: total 45.274 reconstruction 40.185 kl 5.089 val loss total 45.545 reconstruction 40.378 kl 5.166 0:04:32.161583
20/200 1875 loss: total 45.145 reconstruction 40.050 kl 5.096 val loss total 45.616 reconstruction 40.624 kl 4.992 0:04:49.120824
21/200 1875 loss: total 45.091 reconstruction 39.985 kl 5.106 val loss total 45.483 reconstruction 40.361 kl 5.122 0:05:06.047936
22/200 1875 loss: total 45.016 reconstruction 39.878 kl 5.138 val loss total 45.390 reconstruction 40.310 kl 5.080 0:05:22.994144
23/200 1875 loss: total 44.922 reconstruction 39.784 kl 5.138 val loss total 45.303 reconstruction 40.195 kl 5.109 0:05:40.007300
24/200 1875 loss: total 44.867 reconstruction 39.725 kl 5.143 val loss total 45.206 reconstruction 39.997 kl 5.209 0:05:56.867922
25/200 1875 loss: total 44.799 reconstruction 39.647 kl 5.152 val loss total 45.172 reconstruction 40.076 kl 5.096 0:06:13.844618
26/200 1875 loss: total 44.734 reconstruction 39.580 kl 5.155 val loss total 45.151 reconstruction 40.076 kl 5.075 0:06:30.731789
27/200 1875 loss: total 44.687 reconstruction 39.519 kl 5.167 val loss total 45.097 reconstruction 39.935 kl 5.162 0:06:47.798209
28/200 1875 loss: total 44.627 reconstruction 39.453 kl 5.174 val loss total 44.990 reconstruction 39.802 kl 5.188 0:07:04.854064
29/200 1875 loss: total 44.593 reconstruction 39.403 kl 5.190 val loss total 45.101 reconstruction 39.866 kl 5.235 0:07:21.992719
30/200 1875 loss: total 44.552 reconstruction 39.350 kl 5.201 val loss total 44.993 reconstruction 39.789 kl 5.203 0:07:39.136176
31/200 1875 loss: total 44.497 reconstruction 39.305 kl 5.191 val loss total 44.989 reconstruction 39.768 kl 5.220 0:07:56.087327
32/200 1875 loss: total 44.481 reconstruction 39.262 kl 5.218 val loss total 44.969 reconstruction 39.745 kl 5.224 0:08:13.067476
33/200 1875 loss: total 44.407 reconstruction 39.211 kl 5.196 val loss total 44.923 reconstruction 39.727 kl 5.196 0:08:29.940648
34/200 1875 loss: total 44.397 reconstruction 39.187 kl 5.210 val loss total 44.893 reconstruction 39.700 kl 5.193 0:08:46.893910
35/200 1875 loss: total 44.365 reconstruction 39.142 kl 5.223 val loss total 45.044 reconstruction 39.873 kl 5.171 0:09:03.845426
36/200 1875 loss: total 44.341 reconstruction 39.107 kl 5.234 val loss total 44.847 reconstruction 39.639 kl 5.208 0:09:20.738918
37/200 1875 loss: total 44.314 reconstruction 39.079 kl 5.234 val loss total 44.881 reconstruction 39.661 kl 5.220 0:09:37.834646
38/200 1875 loss: total 44.290 reconstruction 39.051 kl 5.239 val loss total 44.874 reconstruction 39.712 kl 5.162 0:09:54.700557
39/200 1875 loss: total 44.257 reconstruction 39.023 kl 5.234 val loss total 44.814 reconstruction 39.629 kl 5.186 0:10:11.610457
40/200 1875 loss: total 44.242 reconstruction 39.004 kl 5.238 val loss total 44.791 reconstruction 39.592 kl 5.198 0:10:28.487917
41/200 1875 loss: total 44.218 reconstruction 38.978 kl 5.240 val loss total 44.816 reconstruction 39.569 kl 5.248 0:10:45.613908
42/200 1875 loss: total 44.209 reconstruction 38.960 kl 5.249 val loss total 44.804 reconstruction 39.562 kl 5.241 0:11:02.583906
43/200 1875 loss: total 44.177 reconstruction 38.931 kl 5.246 val loss total 44.778 reconstruction 39.534 kl 5.245 0:11:19.554249
44/200 1875 loss: total 44.166 reconstruction 38.911 kl 5.254 val loss total 44.814 reconstruction 39.644 kl 5.170 0:11:36.494770
45/200 1875 loss: total 44.147 reconstruction 38.891 kl 5.257 val loss total 44.728 reconstruction 39.469 kl 5.258 0:11:53.552586
46/200 1875 loss: total 44.119 reconstruction 38.873 kl 5.246 val loss total 44.781 reconstruction 39.568 kl 5.213 0:12:10.600668
47/200 1875 loss: total 44.117 reconstruction 38.865 kl 5.252 val loss total 44.770 reconstruction 39.538 kl 5.232 0:12:27.672166
48/200 1875 loss: total 44.102 reconstruction 38.848 kl 5.253 val loss total 44.690 reconstruction 39.448 kl 5.242 0:12:44.810651
49/200 1875 loss: total 44.103 reconstruction 38.836 kl 5.268 val loss total 44.735 reconstruction 39.523 kl 5.212 0:13:01.705094
50/200 1875 loss: total 44.081 reconstruction 38.826 kl 5.255 val loss total 44.755 reconstruction 39.509 kl 5.247 0:13:20.050213
51/200 1875 loss: total 44.078 reconstruction 38.815 kl 5.263 val loss total 44.704 reconstruction 39.473 kl 5.231 0:13:37.115166
52/200 1875 loss: total 44.076 reconstruction 38.807 kl 5.269 val loss total 44.707 reconstruction 39.433 kl 5.274 0:13:54.084028
53/200 1875 loss: total 44.040 reconstruction 38.777 kl 5.263 val loss total 44.717 reconstruction 39.520 kl 5.197 0:14:11.070665
54/200 1875 loss: total 44.039 reconstruction 38.780 kl 5.258 val loss total 44.702 reconstruction 39.442 kl 5.260 0:14:27.848333
55/200 1875 loss: total 44.033 reconstruction 38.766 kl 5.267 val loss total 44.693 reconstruction 39.460 kl 5.232 0:14:44.825219
56/200 1875 loss: total 44.041 reconstruction 38.771 kl 5.270 val loss total 44.678 reconstruction 39.424 kl 5.254 0:15:01.717775
57/200 1875 loss: total 44.034 reconstruction 38.766 kl 5.269 val loss total 44.695 reconstruction 39.425 kl 5.270 0:15:18.633414
58/200 1875 loss: total 44.008 reconstruction 38.740 kl 5.268 val loss total 44.744 reconstruction 39.480 kl 5.265 0:15:35.428503
59/200 1875 loss: total 44.017 reconstruction 38.745 kl 5.273 val loss total 44.681 reconstruction 39.431 kl 5.250 0:15:52.337627
60/200 1875 loss: total 44.019 reconstruction 38.743 kl 5.276 val loss total 44.658 reconstruction 39.402 kl 5.256 0:16:09.232675
61/200 1875 loss: total 44.014 reconstruction 38.733 kl 5.281 val loss total 44.732 reconstruction 39.444 kl 5.289 0:16:26.147037
62/200 1875 loss: total 43.990 reconstruction 38.716 kl 5.274 val loss total 44.631 reconstruction 39.375 kl 5.256 0:16:43.048847
63/200 1875 loss: total 43.975 reconstruction 38.709 kl 5.266 val loss total 44.657 reconstruction 39.411 kl 5.246 0:16:59.897820
64/200 1875 loss: total 44.006 reconstruction 38.727 kl 5.279 val loss total 44.662 reconstruction 39.399 kl 5.263 0:17:16.846156
65/200 1875 loss: total 43.987 reconstruction 38.711 kl 5.276 val loss total 44.684 reconstruction 39.449 kl 5.235 0:17:33.862323
66/200 1875 loss: total 44.004 reconstruction 38.725 kl 5.279 val loss total 44.622 reconstruction 39.358 kl 5.263 0:17:51.177547
67/200 1875 loss: total 43.989 reconstruction 38.705 kl 5.284 val loss total 44.637 reconstruction 39.355 kl 5.282 0:18:08.090772
68/200 1875 loss: total 43.983 reconstruction 38.698 kl 5.285 val loss total 44.675 reconstruction 39.416 kl 5.259 0:18:25.027571
69/200 1875 loss: total 43.965 reconstruction 38.689 kl 5.276 val loss total 44.656 reconstruction 39.399 kl 5.257 0:18:42.010765
70/200 1875 loss: total 43.960 reconstruction 38.688 kl 5.273 val loss total 44.680 reconstruction 39.433 kl 5.247 0:18:59.011968
71/200 1875 loss: total 43.963 reconstruction 38.688 kl 5.274 val loss total 44.615 reconstruction 39.364 kl 5.251 0:19:15.967134
72/200 1875 loss: total 43.963 reconstruction 38.690 kl 5.273 val loss total 44.653 reconstruction 39.398 kl 5.255 0:19:33.001842
73/200 1875 loss: total 43.974 reconstruction 38.693 kl 5.281 val loss total 44.669 reconstruction 39.411 kl 5.258 0:19:49.840520
74/200 1875 loss: total 43.949 reconstruction 38.678 kl 5.272 val loss total 44.685 reconstruction 39.432 kl 5.253 0:20:06.754282
75/200 1875 loss: total 43.964 reconstruction 38.684 kl 5.280 val loss total 44.646 reconstruction 39.391 kl 5.255 0:20:23.710590
76/200 1875 loss: total 43.957 reconstruction 38.679 kl 5.278 val loss total 44.638 reconstruction 39.380 kl 5.257 0:20:40.692416
77/200 1875 loss: total 43.956 reconstruction 38.683 kl 5.272 val loss total 44.682 reconstruction 39.425 kl 5.257 0:20:57.604279
78/200 1875 loss: total 43.958 reconstruction 38.675 kl 5.282 val loss total 44.633 reconstruction 39.379 kl 5.254 0:21:14.509061
79/200 1875 loss: total 43.955 reconstruction 38.679 kl 5.275 val loss total 44.662 reconstruction 39.409 kl 5.253 0:21:31.393506
80/200 1875 loss: total 43.960 reconstruction 38.678 kl 5.281 val loss total 44.699 reconstruction 39.430 kl 5.269 0:21:48.404380
81/200 1875 loss: total 43.952 reconstruction 38.669 kl 5.282 val loss total 44.677 reconstruction 39.415 kl 5.262 0:22:05.221809
82/200 1875 loss: total 43.952 reconstruction 38.675 kl 5.277 val loss total 44.653 reconstruction 39.398 kl 5.254 0:22:22.152764
83/200 1875 loss: total 43.954 reconstruction 38.674 kl 5.280 val loss total 44.640 reconstruction 39.382 kl 5.258 0:22:39.180217
84/200 1875 loss: total 43.949 reconstruction 38.672 kl 5.277 val loss total 44.665 reconstruction 39.411 kl 5.254 0:22:56.274950
85/200 1875 loss: total 43.957 reconstruction 38.676 kl 5.281 val loss total 44.632 reconstruction 39.376 kl 5.256 0:23:13.124952
86/200 1875 loss: total 43.967 reconstruction 38.682 kl 5.285 val loss total 44.642 reconstruction 39.373 kl 5.268 0:23:30.020014
87/200 1875 loss: total 43.954 reconstruction 38.665 kl 5.289 val loss total 44.633 reconstruction 39.366 kl 5.267 0:23:46.979145
88/200 1875 loss: total 43.947 reconstruction 38.662 kl 5.286 val loss total 44.652 reconstruction 39.392 kl 5.260 0:24:03.871217
89/200 1875 loss: total 43.943 reconstruction 38.664 kl 5.279 val loss total 44.596 reconstruction 39.337 kl 5.259 0:24:20.840823
90/200 1875 loss: total 43.945 reconstruction 38.666 kl 5.279 val loss total 44.678 reconstruction 39.421 kl 5.257 0:24:37.875952
91/200 1875 loss: total 43.957 reconstruction 38.677 kl 5.280 val loss total 44.672 reconstruction 39.412 kl 5.260 0:24:54.720054
92/200 1875 loss: total 43.952 reconstruction 38.669 kl 5.283 val loss total 44.673 reconstruction 39.410 kl 5.263 0:25:11.659594
93/200 1875 loss: total 43.951 reconstruction 38.666 kl 5.285 val loss total 44.626 reconstruction 39.364 kl 5.262 0:25:28.669971
94/200 1875 loss: total 43.948 reconstruction 38.666 kl 5.282 val loss total 44.635 reconstruction 39.374 kl 5.261 0:25:45.693580
95/200 1875 loss: total 43.938 reconstruction 38.657 kl 5.281 val loss total 44.637 reconstruction 39.383 kl 5.255 0:26:02.496135
96/200 1875 loss: total 43.952 reconstruction 38.669 kl 5.283 val loss total 44.618 reconstruction 39.353 kl 5.264 0:26:19.480000
97/200 1875 loss: total 43.948 reconstruction 38.664 kl 5.285 val loss total 44.679 reconstruction 39.416 kl 5.263 0:26:36.342742
98/200 1875 loss: total 43.950 reconstruction 38.666 kl 5.284 val loss total 44.681 reconstruction 39.419 kl 5.262 0:26:53.251092
99/200 1875 loss: total 43.935 reconstruction 38.652 kl 5.283 val loss total 44.622 reconstruction 39.361 kl 5.261 0:27:10.224046
100/200 1875 loss: total 43.946 reconstruction 38.666 kl 5.281 val loss total 44.668 reconstruction 39.406 kl 5.262 0:27:28.682459
101/200 1875 loss: total 43.944 reconstruction 38.664 kl 5.281 val loss total 44.629 reconstruction 39.369 kl 5.260 0:27:45.934271
102/200 1875 loss: total 43.949 reconstruction 38.667 kl 5.282 val loss total 44.668 reconstruction 39.406 kl 5.262 0:28:02.983688
103/200 1875 loss: total 43.946 reconstruction 38.663 kl 5.283 val loss total 44.657 reconstruction 39.395 kl 5.262 0:28:19.935472
104/200 1875 loss: total 43.928 reconstruction 38.647 kl 5.281 val loss total 44.669 reconstruction 39.410 kl 5.259 0:28:36.980706
105/200 1875 loss: total 43.957 reconstruction 38.676 kl 5.282 val loss total 44.600 reconstruction 39.338 kl 5.262 0:28:53.968192
106/200 1875 loss: total 43.949 reconstruction 38.666 kl 5.283 val loss total 44.633 reconstruction 39.373 kl 5.261 0:29:10.849823
107/200 1875 loss: total 43.933 reconstruction 38.653 kl 5.280 val loss total 44.681 reconstruction 39.421 kl 5.259 0:29:27.822732
108/200 1875 loss: total 43.937 reconstruction 38.657 kl 5.280 val loss total 44.641 reconstruction 39.382 kl 5.259 0:29:44.927605
109/200 1875 loss: total 43.943 reconstruction 38.662 kl 5.281 val loss total 44.663 reconstruction 39.403 kl 5.260 0:30:01.868636
110/200 1875 loss: total 43.943 reconstruction 38.662 kl 5.281 val loss total 44.715 reconstruction 39.453 kl 5.262 0:30:18.683033
111/200 1875 loss: total 43.941 reconstruction 38.659 kl 5.281 val loss total 44.673 reconstruction 39.413 kl 5.260 0:30:35.620327
112/200 1875 loss: total 43.942 reconstruction 38.661 kl 5.282 val loss total 44.634 reconstruction 39.373 kl 5.260 0:30:52.513338
113/200 1875 loss: total 43.936 reconstruction 38.656 kl 5.281 val loss total 44.627 reconstruction 39.367 kl 5.260 0:31:09.470034
114/200 1875 loss: total 43.948 reconstruction 38.666 kl 5.282 val loss total 44.650 reconstruction 39.388 kl 5.262 0:31:26.374943
115/200 1875 loss: total 43.947 reconstruction 38.664 kl 5.284 val loss total 44.680 reconstruction 39.417 kl 5.263 0:31:43.331110
116/200 1875 loss: total 43.946 reconstruction 38.661 kl 5.284 val loss total 44.618 reconstruction 39.354 kl 5.263 0:32:00.160204
117/200 1875 loss: total 43.932 reconstruction 38.649 kl 5.283 val loss total 44.616 reconstruction 39.353 kl 5.263 0:32:17.110614
118/200 1875 loss: total 43.945 reconstruction 38.662 kl 5.283 val loss total 44.693 reconstruction 39.430 kl 5.263 0:32:34.114515
119/200 1875 loss: total 43.942 reconstruction 38.658 kl 5.284 val loss total 44.667 reconstruction 39.404 kl 5.263 0:32:51.100416
120/200 1875 loss: total 43.943 reconstruction 38.660 kl 5.283 val loss total 44.679 reconstruction 39.417 kl 5.262 0:33:08.075365
121/200 1875 loss: total 43.941 reconstruction 38.658 kl 5.283 val loss total 44.669 reconstruction 39.407 kl 5.262 0:33:24.987791
122/200 1875 loss: total 43.944 reconstruction 38.661 kl 5.282 val loss total 44.657 reconstruction 39.396 kl 5.262 0:33:42.082569
123/200 1875 loss: total 43.950 reconstruction 38.668 kl 5.282 val loss total 44.667 reconstruction 39.406 kl 5.261 0:33:59.461157
124/200 1875 loss: total 43.937 reconstruction 38.655 kl 5.282 val loss total 44.610 reconstruction 39.349 kl 5.261 0:34:16.733678
125/200 1875 loss: total 43.937 reconstruction 38.654 kl 5.282 val loss total 44.647 reconstruction 39.386 kl 5.261 0:34:34.188327
126/200 1875 loss: total 43.933 reconstruction 38.650 kl 5.282 val loss total 44.689 reconstruction 39.428 kl 5.261 0:34:51.088631
127/200 1875 loss: total 43.952 reconstruction 38.670 kl 5.282 val loss total 44.619 reconstruction 39.357 kl 5.261 0:35:08.068828
128/200 1875 loss: total 43.945 reconstruction 38.663 kl 5.282 val loss total 44.649 reconstruction 39.388 kl 5.262 0:35:24.976363
129/200 1875 loss: total 43.955 reconstruction 38.672 kl 5.283 val loss total 44.650 reconstruction 39.388 kl 5.262 0:35:41.957613
130/200 1875 loss: total 43.930 reconstruction 38.648 kl 5.282 val loss total 44.685 reconstruction 39.423 kl 5.261 0:35:58.848861
131/200 1875 loss: total 43.950 reconstruction 38.668 kl 5.282 val loss total 44.658 reconstruction 39.397 kl 5.261 0:36:15.730307
132/200 1875 loss: total 43.944 reconstruction 38.661 kl 5.282 val loss total 44.643 reconstruction 39.381 kl 5.261 0:36:32.735192
133/200 1875 loss: total 43.932 reconstruction 38.650 kl 5.282 val loss total 44.620 reconstruction 39.359 kl 5.261 0:36:49.600559
134/200 1875 loss: total 43.940 reconstruction 38.658 kl 5.282 val loss total 44.645 reconstruction 39.383 kl 5.261 0:37:06.402081
135/200 1875 loss: total 43.931 reconstruction 38.649 kl 5.282 val loss total 44.619 reconstruction 39.358 kl 5.261 0:37:23.151727
136/200 1875 loss: total 43.935 reconstruction 38.653 kl 5.282 val loss total 44.628 reconstruction 39.366 kl 5.262 0:37:40.379000
137/200 1875 loss: total 43.936 reconstruction 38.654 kl 5.282 val loss total 44.634 reconstruction 39.373 kl 5.261 0:37:57.313128
138/200 1875 loss: total 43.926 reconstruction 38.644 kl 5.282 val loss total 44.672 reconstruction 39.411 kl 5.261 0:38:14.199890
139/200 1875 loss: total 43.943 reconstruction 38.661 kl 5.282 val loss total 44.640 reconstruction 39.379 kl 5.261 0:38:31.248823
140/200 1875 loss: total 43.943 reconstruction 38.661 kl 5.282 val loss total 44.633 reconstruction 39.372 kl 5.261 0:38:48.186538
141/200 1875 loss: total 43.940 reconstruction 38.658 kl 5.282 val loss total 44.685 reconstruction 39.424 kl 5.261 0:39:05.054026
142/200 1875 loss: total 43.938 reconstruction 38.656 kl 5.282 val loss total 44.652 reconstruction 39.391 kl 5.261 0:39:22.016986
143/200 1875 loss: total 43.954 reconstruction 38.672 kl 5.282 val loss total 44.670 reconstruction 39.408 kl 5.261 0:39:38.908782
144/200 1875 loss: total 43.920 reconstruction 38.638 kl 5.282 val loss total 44.660 reconstruction 39.399 kl 5.261 0:39:55.807692
145/200 1875 loss: total 43.941 reconstruction 38.659 kl 5.282 val loss total 44.634 reconstruction 39.372 kl 5.261 0:40:12.684311
146/200 1875 loss: total 43.938 reconstruction 38.656 kl 5.282 val loss total 44.591 reconstruction 39.330 kl 5.261 0:40:29.625403
147/200 1875 loss: total 43.942 reconstruction 38.660 kl 5.282 val loss total 44.665 reconstruction 39.404 kl 5.261 0:40:46.576435
148/200 1875 loss: total 43.958 reconstruction 38.676 kl 5.282 val loss total 44.623 reconstruction 39.362 kl 5.261 0:41:03.477569
149/200 1875 loss: total 43.928 reconstruction 38.646 kl 5.282 val loss total 44.668 reconstruction 39.407 kl 5.261 0:41:20.307208
150/200 1875 loss: total 43.937 reconstruction 38.655 kl 5.282 val loss total 44.646 reconstruction 39.385 kl 5.261 0:41:38.880538
151/200 1875 loss: total 43.932 reconstruction 38.650 kl 5.282 val loss total 44.659 reconstruction 39.398 kl 5.261 0:41:55.882208
152/200 1875 loss: total 43.957 reconstruction 38.675 kl 5.282 val loss total 44.652 reconstruction 39.391 kl 5.261 0:42:12.801339
153/200 1875 loss: total 43.941 reconstruction 38.659 kl 5.282 val loss total 44.654 reconstruction 39.393 kl 5.261 0:42:29.583425
154/200 1875 loss: total 43.963 reconstruction 38.681 kl 5.282 val loss total 44.635 reconstruction 39.373 kl 5.261 0:42:46.678708
155/200 1875 loss: total 43.930 reconstruction 38.648 kl 5.282 val loss total 44.655 reconstruction 39.393 kl 5.261 0:43:03.549905
156/200 1875 loss: total 43.943 reconstruction 38.661 kl 5.282 val loss total 44.624 reconstruction 39.363 kl 5.261 0:43:20.456769
157/200 1875 loss: total 43.930 reconstruction 38.648 kl 5.282 val loss total 44.650 reconstruction 39.389 kl 5.261 0:43:37.503061
158/200 1875 loss: total 43.928 reconstruction 38.646 kl 5.282 val loss total 44.624 reconstruction 39.363 kl 5.261 0:43:54.362039
159/200 1875 loss: total 43.944 reconstruction 38.662 kl 5.282 val loss total 44.632 reconstruction 39.370 kl 5.261 0:44:11.219079
160/200 1875 loss: total 43.938 reconstruction 38.656 kl 5.282 val loss total 44.659 reconstruction 39.398 kl 5.261 0:44:28.113912
161/200 1875 loss: total 43.933 reconstruction 38.650 kl 5.282 val loss total 44.660 reconstruction 39.399 kl 5.261 0:44:45.097979
162/200 1875 loss: total 43.953 reconstruction 38.671 kl 5.282 val loss total 44.624 reconstruction 39.362 kl 5.261 0:45:01.963802
163/200 1875 loss: total 43.930 reconstruction 38.648 kl 5.282 val loss total 44.674 reconstruction 39.413 kl 5.261 0:45:18.820460
164/200 1875 loss: total 43.938 reconstruction 38.656 kl 5.282 val loss total 44.649 reconstruction 39.388 kl 5.261 0:45:35.656687
165/200 1875 loss: total 43.926 reconstruction 38.644 kl 5.282 val loss total 44.646 reconstruction 39.385 kl 5.261 0:45:52.580641
166/200 1875 loss: total 43.943 reconstruction 38.661 kl 5.282 val loss total 44.647 reconstruction 39.386 kl 5.261 0:46:09.525616
167/200 1875 loss: total 43.953 reconstruction 38.670 kl 5.282 val loss total 44.635 reconstruction 39.373 kl 5.261 0:46:26.353054
168/200 1875 loss: total 43.930 reconstruction 38.647 kl 5.282 val loss total 44.670 reconstruction 39.408 kl 5.261 0:46:43.235755
169/200 1875 loss: total 43.946 reconstruction 38.664 kl 5.282 val loss total 44.615 reconstruction 39.354 kl 5.261 0:47:00.104168
170/200 1875 loss: total 43.954 reconstruction 38.672 kl 5.282 val loss total 44.652 reconstruction 39.391 kl 5.261 0:47:16.959041
171/200 1875 loss: total 43.948 reconstruction 38.666 kl 5.282 val loss total 44.674 reconstruction 39.413 kl 5.261 0:47:33.885171
172/200 1875 loss: total 43.929 reconstruction 38.647 kl 5.282 val loss total 44.653 reconstruction 39.392 kl 5.261 0:47:50.866081
173/200 1875 loss: total 43.941 reconstruction 38.659 kl 5.282 val loss total 44.650 reconstruction 39.388 kl 5.261 0:48:07.802663
174/200 1875 loss: total 43.938 reconstruction 38.656 kl 5.282 val loss total 44.651 reconstruction 39.389 kl 5.261 0:48:24.663800
175/200 1875 loss: total 43.949 reconstruction 38.667 kl 5.282 val loss total 44.657 reconstruction 39.396 kl 5.261 0:48:41.890727
176/200 1875 loss: total 43.950 reconstruction 38.667 kl 5.282 val loss total 44.661 reconstruction 39.399 kl 5.261 0:48:58.734299
177/200 1875 loss: total 43.947 reconstruction 38.665 kl 5.282 val loss total 44.614 reconstruction 39.352 kl 5.261 0:49:15.652162
178/200 1875 loss: total 43.928 reconstruction 38.645 kl 5.282 val loss total 44.636 reconstruction 39.374 kl 5.261 0:49:32.531961
179/200 1875 loss: total 43.948 reconstruction 38.665 kl 5.282 val loss total 44.654 reconstruction 39.393 kl 5.261 0:49:49.461183
180/200 1875 loss: total 43.939 reconstruction 38.657 kl 5.282 val loss total 44.658 reconstruction 39.396 kl 5.261 0:50:06.401785
181/200 1875 loss: total 43.942 reconstruction 38.660 kl 5.282 val loss total 44.636 reconstruction 39.375 kl 5.261 0:50:23.392329
182/200 1875 loss: total 43.948 reconstruction 38.666 kl 5.282 val loss total 44.657 reconstruction 39.396 kl 5.261 0:50:40.369676
183/200 1875 loss: total 43.938 reconstruction 38.655 kl 5.282 val loss total 44.595 reconstruction 39.334 kl 5.261 0:50:57.446461
184/200 1875 loss: total 43.946 reconstruction 38.663 kl 5.282 val loss total 44.650 reconstruction 39.389 kl 5.261 0:51:15.021246
185/200 1875 loss: total 43.936 reconstruction 38.654 kl 5.282 val loss total 44.635 reconstruction 39.373 kl 5.261 0:51:32.001082
186/200 1875 loss: total 43.935 reconstruction 38.653 kl 5.282 val loss total 44.647 reconstruction 39.386 kl 5.261 0:51:48.915687
187/200 1875 loss: total 43.939 reconstruction 38.657 kl 5.282 val loss total 44.630 reconstruction 39.369 kl 5.261 0:52:05.872774
188/200 1875 loss: total 43.939 reconstruction 38.657 kl 5.282 val loss total 44.634 reconstruction 39.373 kl 5.261 0:52:22.803600
189/200 1875 loss: total 43.956 reconstruction 38.674 kl 5.282 val loss total 44.622 reconstruction 39.360 kl 5.261 0:52:39.974099
190/200 1875 loss: total 43.932 reconstruction 38.650 kl 5.282 val loss total 44.638 reconstruction 39.376 kl 5.261 0:52:57.120551
191/200 1875 loss: total 43.935 reconstruction 38.653 kl 5.282 val loss total 44.649 reconstruction 39.387 kl 5.261 0:53:14.245814
192/200 1875 loss: total 43.943 reconstruction 38.661 kl 5.282 val loss total 44.649 reconstruction 39.388 kl 5.261 0:53:31.120179
193/200 1875 loss: total 43.945 reconstruction 38.663 kl 5.282 val loss total 44.685 reconstruction 39.424 kl 5.261 0:53:48.193309
194/200 1875 loss: total 43.933 reconstruction 38.651 kl 5.282 val loss total 44.640 reconstruction 39.379 kl 5.261 0:54:05.205495
195/200 1875 loss: total 43.938 reconstruction 38.656 kl 5.282 val loss total 44.662 reconstruction 39.401 kl 5.261 0:54:22.157001
196/200 1875 loss: total 43.933 reconstruction 38.651 kl 5.282 val loss total 44.655 reconstruction 39.394 kl 5.261 0:54:39.042646
197/200 1875 loss: total 43.944 reconstruction 38.662 kl 5.282 val loss total 44.642 reconstruction 39.381 kl 5.261 0:54:55.991865
198/200 1875 loss: total 43.934 reconstruction 38.652 kl 5.282 val loss total 44.655 reconstruction 39.394 kl 5.261 0:55:12.922184
199/200 1875 loss: total 43.956 reconstruction 38.674 kl 5.282 val loss total 44.643 reconstruction 39.382 kl 5.261 0:55:29.843118
200/200 1875 loss: total 43.938 reconstruction 38.656 kl 5.282 val loss total 44.658 reconstruction 39.396 kl 5.261 0:55:48.264540
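In the training log above, the total loss is simply the sum of the reconstruction loss and the KL loss; at the final epoch, 38.656 + 5.282 = 43.938. A minimal check of this relationship, using values copied from the last line of the log:

# Values copied from epoch 200/200 of the log above.
reconstruction_loss = 38.656
kl_loss = 5.282
total_loss = reconstruction_loss + kl_loss
print(f'{total_loss:.3f}')  # 43.938 -- matches the reported total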
In [54]:
# Extract the per-epoch losses (total / reconstruction / KL, training and validation)
# recorded during the second training run.
loss3_2 = log3_2['loss']
rloss3_2 = log3_2['reconstruction_loss']
kloss3_2 = log3_2['kl_loss']
val_loss3_2 = log3_2['val_loss']
val_rloss3_2 = log3_2['val_reconstruction_loss']
val_kloss3_2 = log3_2['val_kl_loss']
In [55]:
# Join the loss histories of the first and second training runs along the epoch axis.
loss3 = np.concatenate([loss3_1, loss3_2], axis=0)
rloss3 = np.concatenate([rloss3_1, rloss3_2], axis=0)
kloss3 = np.concatenate([kloss3_1, kloss3_2], axis=0)

val_loss3 = np.concatenate([val_loss3_1, val_loss3_2], axis=0)
val_rloss3 = np.concatenate([val_rloss3_1, val_rloss3_2], axis=0)
val_kloss3 = np.concatenate([val_kloss3_1, val_kloss3_2], axis=0)
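As a quick sanity check (a minimal sketch that only assumes the arrays built above), the merged histories should span both training runs end to end:

# The merged history covers the first run followed by the second run.
assert len(loss3) == len(loss3_1) + len(loss3_2)
assert len(val_loss3) == len(val_loss3_1) + len(val_loss3_2)
print(len(loss3), 'epochs in total')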
In [56]:
VariationalAutoEncoder.plot_history(
    [loss3, val_loss3], 
    ['total_loss', 'val_total_loss']
)
In [57]:
VariationalAutoEncoder.plot_history(
    [rloss3, val_rloss3], 
    ['reconstruction_loss', 'val_reconstruction_loss']
)
In [58]:
VariationalAutoEncoder.plot_history(
    [kloss3, val_kloss3], 
    ['kl_loss', 'val_kl_loss']
)
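plot_history is a class method defined earlier in this notebook. For reference, here is a minimal standalone sketch with matplotlib that mirrors the calls above; it is an illustrative equivalent, not the original implementation:

import matplotlib.pyplot as plt

def plot_history(histories, labels):
    # Plot each loss curve against the epoch index and label it.
    plt.figure(figsize=(8, 4))
    for history, label in zip(histories, labels):
        plt.plot(range(len(history)), history, label=label)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.show()

# Usage, equivalent to the cell above:
# plot_history([kloss3, val_kloss3], ['kl_loss', 'val_kl_loss'])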
In [59]:
z_mean3, z_log_var3, z3 = vae3_work.encoder(selected_images)
reconst_images3 = vae3_work.decoder(z3).numpy()  # decoder() is declared with @tf.function, so it returns a Tensor; convert it to a NumPy array

txts3 = [f'{p[0]:.3f}, {p[1]:.3f}' for p in z3 ]
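Note that the encoder returns three values: the mean vector, the log-variance vector, and a latent sample z drawn from the corresponding normal distribution via the reparameterization trick. A minimal sketch of that sampling step (standard VAE formulation; the function name is illustrative):

import tensorflow as tf

def sample_z(z_mean, z_log_var):
    # Reparameterization trick: z = mean + sigma * epsilon, with epsilon ~ N(0, I),
    # where sigma is recovered from the log-variance as exp(0.5 * log_var).
    epsilon = tf.random.normal(shape=tf.shape(z_mean))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

# z3 above is one such sample centered at z_mean3, with spread exp(0.5 * z_log_var3).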
In [60]:
%matplotlib inline

VariationalAutoEncoder.showImages(selected_images, reconst_images3, txts3, 1.4, 1.4)
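showImages is also a helper defined earlier in the notebook; it displays each selected image above its reconstruction, captioned with the sampled latent coordinates. A rough matplotlib sketch of the same layout (assumed, not the original code):

import matplotlib.pyplot as plt

def show_images(originals, reconstructions, captions, w=1.4, h=1.4):
    n = len(originals)
    fig, axes = plt.subplots(2, n, figsize=(w * n, h * 2), squeeze=False)
    for i in range(n):
        # Top row: original image, captioned with its latent coordinates.
        axes[0, i].imshow(originals[i].squeeze(), cmap='gray')
        axes[0, i].set_title(captions[i], fontsize=6)
        axes[0, i].axis('off')
        # Bottom row: reconstruction decoded from the sampled latent point.
        axes[1, i].imshow(reconstructions[i].squeeze(), cmap='gray')
        axes[1, i].axis('off')
    plt.show()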