Updated 29/Nov/2021 by Yoshihisa Nitta  

Cycle Generative Adversarial Network for the VidTIMIT dataset with Tensorflow 2 on Google Colab

Train a Cycle Generative Adversarial Network (CycleGAN) on the VidTIMIT dataset.

In [1]:
#! pip install tensorflow==2.7.0
In [2]:
! pip install tensorflow_addons
Requirement already satisfied: tensorflow_addons in /usr/local/lib/python3.7/dist-packages (0.15.0)
Requirement already satisfied: typeguard>=2.7 in /usr/local/lib/python3.7/dist-packages (from tensorflow_addons) (2.7.1)
In [3]:
%tensorflow_version 2.x

import tensorflow as tf
print(tf.__version__)
2.7.0
In [4]:
import numpy as np

np.random.seed(2022)

Check the Google Colab runtime environment

In [5]:
! nvidia-smi
! cat /proc/cpuinfo
! cat /etc/issue
! free -h
Mon Nov 29 12:44:19 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 85
model name	: Intel(R) Xeon(R) CPU @ 2.00GHz
stepping	: 3
microcode	: 0x1
cpu MHz		: 2000.180
cache size	: 39424 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4000.36
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

processor	: 1
vendor_id	: GenuineIntel
cpu family	: 6
model		: 85
model name	: Intel(R) Xeon(R) CPU @ 2.00GHz
stepping	: 3
microcode	: 0x1
cpu MHz		: 2000.180
cache size	: 39424 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 1
initial apicid	: 1
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4000.36
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

Ubuntu 18.04.5 LTS \n \l

              total        used        free      shared  buff/cache   available
Mem:            12G        906M        2.0G        1.3M        9.8G         11G
Swap:            0B          0B          0B

Mount Google Drive from Google Colab

In [6]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [7]:
! ls /content/drive
MyDrive  Shareddrives

Download source file from Google Drive or nw.tsuda.ac.jp

Normally, download from Google Drive with gdown. Download from nw.tsuda.ac.jp only if the Google Drive specifications change and the download from Google Drive no longer works.

In [8]:
# Download source file
nw_path = './nw'
! rm -rf {nw_path}
! mkdir -p {nw_path}

if True:   # from Google Drive
    url_model =  'https://drive.google.com/uc?id=1aNvpPDNeDWYQFu_PA1kOtFlzcO5seHky'
    ! (cd {nw_path}; gdown {url_model})
else:      # from nw.tsuda.ac.jp
    URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'
    url_model = f'{URL_NW}/models/CycleGAN.py'
    ! wget -nd {url_model} -P {nw_path}
Downloading...
From: https://drive.google.com/uc?id=1aNvpPDNeDWYQFu_PA1kOtFlzcO5seHky
To: /content/nw/CycleGAN.py
100% 24.6k/24.6k [00:00<00:00, 37.7MB/s]
In [9]:
! cat {nw_path}/CycleGAN.py
import tensorflow as tf
import tensorflow_addons as tf_addons
import numpy as np

import matplotlib.pyplot as plt

from collections import deque

import os
import pickle as pkl
import random
import datetime


################################################################################
# Data Loader
################################################################################
class PairDataset():
    def __init__(self, paths_A, paths_B, batch_size= 1, target_size = None, unaligned=False):
        self.paths_A = np.array(paths_A)
        self.paths_B = np.array(paths_B)
        self.target_size = target_size
        self.batch_size = batch_size
        self.unaligned = unaligned

        self.lenA = len(paths_A)
        self.lenB = len(paths_B)
        self.index = 0

    def __len__(self):
        return max(self.lenA, self.lenB)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # slice.indices() always returns concrete integers, so the
            # (start, stop, step) triple can be passed straight to range()
            start, stop, step = index.indices(self.__len__())
            return np.array([self.__getitemInt__(i) for i in range(start, stop, step)])
        else:
            return self.__getitemInt__(index)

    def __getitemInt__(self, index):
        path_A = self.paths_A[index % self.lenA]
        if self.unaligned:
            path_B = self.paths_B[np.random.choice(self.lenB)]   # np.random.choice(n) returns a scalar index
        else:
            path_B = self.paths_B[index % self.lenB]
        img_A = np.array(tf.keras.utils.load_img(path_A, target_size = self.target_size))
        img_B = np.array(tf.keras.utils.load_img(path_B, target_size = self.target_size))
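        # scale pixel values from [0, 255] to [-1, 1], matching the tanh output of the generators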
        img_A = (img_A.astype('float32') - 127.5) / 127.5
        img_B = (img_B.astype('float32') - 127.5) / 127.5
        return np.array([img_A, img_B])

    def __next__(self):
        self.index += 1
        return self.__getitem__(self.index-1)



################################################################################
# Layer
################################################################################
class ReflectionPadding2D(tf.keras.layers.Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [tf.keras.layers.InputSpec(ndim=4)]
        super().__init__(**kwargs)

    def compute_output_shape(self, s):
        '''
        If you are using "channels_last" configuration
        '''
        return (s[0], s[1]+2*self.padding[0], s[2]+2*self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
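        # 'REFLECT' mirrors the rows/columns adjacent to each border without
        # repeating the edge pixel itself, e.g. [1, 2, 3] padded by 1 -> [2, 1, 2, 3, 2]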
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')


################################################################################
# Model
################################################################################
class CycleGAN():
    def __init__(
        self,
        input_dim,
        learning_rate,
        lambda_validation,
        lambda_reconstr,
        lambda_id,
        generator_type,
        gen_n_filters,
        disc_n_filters,
        buffer_max_length = 50,
        epoch = 0, 
        d_losses = [],
        g_losses = []
    ):
        self.input_dim = input_dim
        self.learning_rate = learning_rate
        self.buffer_max_length = buffer_max_length
        self.lambda_validation = lambda_validation
        self.lambda_reconstr = lambda_reconstr
        self.lambda_id = lambda_id
        self.generator_type = generator_type
        self.gen_n_filters = gen_n_filters
        self.disc_n_filters = disc_n_filters

        # Input shape
        self.img_rows = input_dim[0]
        self.img_cols = input_dim[1]
        self.channels = input_dim[2]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        
        self.epoch = epoch
        self.d_losses = d_losses
        self.g_losses = g_losses

        self.buffer_A = deque(maxlen=self.buffer_max_length)
        self.buffer_B = deque(maxlen=self.buffer_max_length)

        # Calculate output shape of D (PatchGAN)
        patch = int(self.img_rows / 2**3)
        self.disc_patch = (patch, patch, 1)
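        # e.g. a 128x128 input gives a 16x16x1 validity map, since the
        # discriminator halves the resolution three times (strides 2, 2, 2, 1)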
        
        self.weight_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        
        self.compile_models()


    def compile_models(self):
        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()

        self.d_A.compile(
            loss='mse',
            optimizer=tf.keras.optimizers.Adam(self.learning_rate, 0.5),
            metrics=['accuracy']
        )
        self.d_B.compile(
            loss='mse',
            optimizer=tf.keras.optimizers.Adam(self.learning_rate, 0.5),
            metrics=['accuracy']
        )

        # Build the generators
        if self.generator_type == 'unet':
            self.g_AB = self.build_generator_unet()
            self.g_BA = self.build_generator_unet()
        else:
            self.g_AB = self.build_generator_resnet()
            self.g_BA = self.build_generator_resnet()

        # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Input images from both domains
        img_A = tf.keras.layers.Input(shape=self.img_shape)
        img_B = tf.keras.layers.Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)

        # translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)

        # Identity mapping: feeding a domain-A image through g_BA (and a
        # domain-B image through g_AB) should leave it unchanged
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # Discriminators determine validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators
        self.combined = tf.keras.models.Model(
            inputs=[img_A, img_B],
            outputs=[valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id]
        )
        self.combined.compile(
            loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],  # Mean Squared Error, Mean Absolute Error
            loss_weights=[ self.lambda_validation, self.lambda_validation,
                         self.lambda_reconstr, self.lambda_reconstr,
                         self.lambda_id, self.lambda_id ],
            optimizer=tf.keras.optimizers.Adam(0.0002, 0.5)
        )
        self.d_A.trainable = True
        self.d_B.trainable = True
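        # Keras bakes the trainable flags in at compile time, so the frozen
        # state only affects self.combined; restoring trainable=True here keeps
        # the direct d_A/d_B updates in train_discriminators() working.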


    def build_generator_unet(self):
        def downsample(layer_input, filters, f_size=4):
            d = tf.keras.layers.Conv2D(
                filters,
                kernel_size=f_size,
                strides=2,
                padding='same',
                kernel_initializer = self.weight_init  # [self memo] added by nitta
            )(layer_input)
            d = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(d)
            d = tf.keras.layers.Activation('relu')(d)
            return d
        def upsample(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            u = tf.keras.layers.UpSampling2D(size=2)(layer_input)
            u = tf.keras.layers.Conv2D(
                filters, 
                kernel_size=f_size, 
                strides=1, 
                padding='same',
                kernel_initializer = self.weight_init  # [self memo] added by nitta
            )(u)
            u = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(u)
            u = tf.keras.layers.Activation('relu')(u)
            if dropout_rate:
                u = tf.keras.layers.Dropout(dropout_rate)(u)
            u = tf.keras.layers.Concatenate()([u, skip_input])
            return u
        # Image input
        img = tf.keras.layers.Input(shape=self.img_shape)
        # Downsampling
        d1 = downsample(img, self.gen_n_filters)
        d2 = downsample(d1, self.gen_n_filters*2)
        d3 = downsample(d2, self.gen_n_filters*4)
        d4 = downsample(d3, self.gen_n_filters*8)

        # Upsampling
        u1 = upsample(d4, d3, self.gen_n_filters*4)
        u2 = upsample(u1, d2, self.gen_n_filters*2)
        u3 = upsample(u2, d1, self.gen_n_filters)

        u4 = tf.keras.layers.UpSampling2D(size=2)(u3)
        output_img = tf.keras.layers.Conv2D(
            self.channels, 
            kernel_size=4,
            strides=1, 
            padding='same',
            activation='tanh',
            kernel_initializer = self.weight_init  # [self memo] added by nitta
        )(u4)

        return tf.keras.models.Model(img, output_img)


    def build_generator_resnet(self):
        def conv7s1(layer_input, filters, final):
            y = ReflectionPadding2D(padding=(3,3))(layer_input)
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(7,7),
                strides=1,
                padding='valid',
                kernel_initializer=self.weight_init
            )(y)
            if final:
                y = tf.keras.layers.Activation('tanh')(y)
            else:
                y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
                y = tf.keras.layers.Activation('relu')(y)
            return y

        def downsample(layer_input, filters):
            y = tf.keras.layers.Conv2D(
                filters, 
                kernel_size=(3,3), 
                strides=2, 
                padding='same',
                kernel_initializer = self.weight_init
            )(layer_input)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.Activation('relu')(y)
            return y

        def residual(layer_input, filters):
            shortcut = layer_input
            y = ReflectionPadding2D(padding=(1,1))(layer_input)
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(3,3),
                strides=1,
                padding='valid',
                kernel_initializer=self.weight_init
            )(y)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.Activation('relu')(y)
            y = ReflectionPadding2D(padding=(1,1))(y)
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(3,3),
                strides=1,
                padding='valid',
                kernel_initializer=self.weight_init
            )(y)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            return tf.keras.layers.add([shortcut, y])
          
        def upsample(layer_input, filters):
            y = tf.keras.layers.Conv2DTranspose(
                filters, 
                kernel_size=(3,3), 
                strides=2,
                padding='same',
                kernel_initializer=self.weight_init
            )(layer_input)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.Activation('relu')(y)
            return y

        # Image input
        img = tf.keras.layers.Input(shape=self.img_shape)

        y = img
        y = conv7s1(y, self.gen_n_filters, False)
        y = downsample(y, self.gen_n_filters * 2)
        y = downsample(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = upsample(y, self.gen_n_filters * 2)
        y = upsample(y, self.gen_n_filters)
        y = conv7s1(y, 3, True)
        output = y
        
        return tf.keras.models.Model(img, output)


    def build_discriminator(self):
        def conv4(layer_input, filters, stride=2, norm=True):
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(4,4),
                strides=stride,
                padding='same',
                kernel_initializer = self.weight_init
              )(layer_input)
            if norm:
                y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.LeakyReLU(0.2)(y)
            return y

        img = tf.keras.layers.Input(shape=self.img_shape)
        y = conv4(img, self.disc_n_filters, stride=2, norm=False)
        y = conv4(y, self.disc_n_filters*2, stride=2)
        y = conv4(y, self.disc_n_filters*4, stride=2)
        y = conv4(y, self.disc_n_filters*8, stride=1)
        output = tf.keras.layers.Conv2D(
            1,
            kernel_size=4,
            strides=1,
            padding='same',
            kernel_initializer=self.weight_init
        )(y)
        return tf.keras.models.Model(img, output)


    def train_discriminators(self, imgs_A, imgs_B, valid, fake):
        # Translate images to opposite domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        
        self.buffer_B.append(fake_B)
        self.buffer_A.append(fake_A)

        fake_A_rnd = random.sample(self.buffer_A, min(len(self.buffer_A), len(imgs_A))) # random sampling without replacement 
        fake_B_rnd = random.sample(self.buffer_B, min(len(self.buffer_B), len(imgs_B)))
        
        # Train the discriminators (original images=real / translated = fake)
        dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
        dA_loss_fake = self.d_A.train_on_batch(fake_A_rnd, fake)
        dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

        dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
        dB_loss_fake = self.d_B.train_on_batch(fake_B_rnd, fake)
        dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

        # Total discriminator loss
        d_loss_total = 0.5 * np.add(dA_loss, dB_loss)

        return (
            d_loss_total[0], 
            dA_loss[0], dA_loss_real[0], dA_loss_fake[0],
            dB_loss[0], dB_loss_real[0], dB_loss_fake[0],
            d_loss_total[1], 
            dA_loss[1], dA_loss_real[1], dA_loss_fake[1],
            dB_loss[1], dB_loss_real[1], dB_loss_fake[1]
        )


    def train_generators(self, imgs_A, imgs_B, valid):
        # Train the generators
        return self.combined.train_on_batch(
          [imgs_A, imgs_B], 
          [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B]
        )


    def train(self, data_loader, epochs, batch_size=1, run_folder='./run', print_step_interval=100, save_epoch_interval=100):
        start_time = datetime.datetime.now()
        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        steps = len(data_loader) // batch_size
        for epoch in range(self.epoch, epochs):
            step_d_losses = []
            step_g_losses = []
            for step in range(steps):
                start = step * batch_size
                end = start + batch_size
                pairs = data_loader[start:end]    # ((a,b), (a, b), ....)
                imgs_A, imgs_B = [], []
                for img_A, img_B in pairs:
                    imgs_A.append(img_A)
                    imgs_B.append(img_B)

                imgs_A = np.array(imgs_A)
                imgs_B = np.array(imgs_B)

                step_d_loss = self.train_discriminators(imgs_A, imgs_B, valid, fake)
                step_g_loss = self.train_generators(imgs_A, imgs_B, valid)

                step_d_losses.append(step_d_loss)
                step_g_losses.append(step_g_loss)

                elapsed_time = datetime.datetime.now() - start_time
                if (step+1) % print_step_interval == 0:
                    print(f'Epoch {epoch+1}/{epochs} {step+1}/{steps} [D loss: {step_d_loss[0]:.3f} acc: {step_d_loss[7]:.3f}][G loss: {step_g_loss[0]:.3f} adv: {np.sum(step_g_loss[1:3]):.3f} recon: {np.sum(step_g_loss[3:5]):.3f} id: {np.sum(step_g_loss[5:7]):.3f} time: {elapsed_time}]')

            d_loss = np.mean(step_d_losses, axis=0)
            g_loss = np.mean(step_g_losses, axis=0)

            elapsed_time = datetime.datetime.now() - start_time
            print(f'Epoch {epoch+1}/{epochs} [D loss: {d_loss[0]:.3f} acc: {d_loss[7]:.3f}][G loss: {g_loss[0]:.3f} adv: {np.sum(g_loss[1:3]):.3f} recon: {np.sum(g_loss[3:5]):.3f} id: {np.sum(g_loss[5:7]):.3f} time: {elapsed_time}]')
                    
            self.d_losses.append(d_loss)
            self.g_losses.append(g_loss)

            self.epoch += 1
            if (self.epoch) % save_epoch_interval == 0:
                self.save(run_folder, self.epoch)
                self.save(run_folder)

        self.save(run_folder, self.epoch)
        self.save(run_folder)


    def save(self, folder, epoch=None):
        self.save_params(folder, epoch)
        self.save_weights(folder,epoch)


    @staticmethod
    def load(folder, epoch=None):
        params = CycleGAN.load_params(folder, epoch)
        gan = CycleGAN(*params)
        gan.load_weights(folder, epoch)
        return gan


    def save_weights(self, run_folder, epoch=None):
        if epoch is None:
            self.save_model_weights(self.combined, os.path.join(run_folder, 'weights/combined-weights.h5'))
            self.save_model_weights(self.d_A, os.path.join(run_folder, 'weights/d_A-weights.h5'))
            self.save_model_weights(self.d_B, os.path.join(run_folder, 'weights/d_B-weights.h5'))
            self.save_model_weights(self.g_AB, os.path.join(run_folder, 'weights/g_AB-weights.h5'))
            self.save_model_weights(self.g_BA, os.path.join(run_folder, 'weights/g_BA-weights.h5'))
        else:
            self.save_model_weights(self.combined, os.path.join(run_folder, f'weights/combined-weights_{epoch}.h5'))
            self.save_model_weights(self.d_A, os.path.join(run_folder, f'weights/d_A-weights_{epoch}.h5'))
            self.save_model_weights(self.d_B, os.path.join(run_folder, f'weights/d_B-weights_{epoch}.h5'))
            self.save_model_weights(self.g_AB, os.path.join(run_folder, f'weights/g_AB-weights_{epoch}.h5'))
            self.save_model_weights(self.g_BA, os.path.join(run_folder, f'weights/g_BA-weights_{epoch}.h5'))


    def load_weights(self, run_folder, epoch=None):
        if epoch is None:
            self.load_model_weights(self.combined, os.path.join(run_folder, 'weights/combined-weights.h5'))
            self.load_model_weights(self.d_A, os.path.join(run_folder, 'weights/d_A-weights.h5'))
            self.load_model_weights(self.d_B, os.path.join(run_folder, 'weights/d_B-weights.h5'))
            self.load_model_weights(self.g_AB, os.path.join(run_folder, 'weights/g_AB-weights.h5'))
            self.load_model_weights(self.g_BA, os.path.join(run_folder, 'weights/g_BA-weights.h5'))
        else:
            self.load_model_weights(self.combined, os.path.join(run_folder, f'weights/combined-weights_{epoch}.h5'))
            self.load_model_weights(self.d_A, os.path.join(run_folder, f'weights/d_A-weights_{epoch}.h5'))
            self.load_model_weights(self.d_B, os.path.join(run_folder, f'weights/d_B-weights_{epoch}.h5'))
            self.load_model_weights(self.g_AB, os.path.join(run_folder, f'weights/g_AB-weights_{epoch}.h5'))
            self.load_model_weights(self.g_BA, os.path.join(run_folder, f'weights/g_BA-weights_{epoch}.h5'))


    def save_model_weights(self, model, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        model.save_weights(filepath)


    def load_model_weights(self, model, filepath):
        model.load_weights(filepath)


    def save_params(self, folder, epoch=None):
        if epoch is None:
            filepath = os.path.join(folder, 'params.pkl')
        else:
            filepath = os.path.join(folder, f'params_{epoch}.pkl')

        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)

        with open(filepath, 'wb') as f:
            pkl.dump([
                self.input_dim,
                self.learning_rate,
                self.lambda_validation,
                self.lambda_reconstr,
                self.lambda_id,
                self.generator_type,
                self.gen_n_filters,
                self.disc_n_filters,
                self.buffer_max_length,
                self.epoch,
                self.d_losses,
                self.g_losses
              ], f)


    @staticmethod
    def load_params(folder, epoch=None):
        if epoch is None:
            filepath = os.path.join(folder, 'params.pkl')
        else:
            filepath = os.path.join(folder, f'params_{epoch}.pkl')

        with open(filepath, 'rb') as f:
            params = pkl.load(f)
        return params


    def generate_image(self, img_A, img_B):
        gen_A = self.generate_image_from_A(img_A)
        gen_B = self.generate_image_from_B(img_B)
        return np.concatenate([gen_A, gen_B], axis=0)


    def generate_image_from_A(self, img_A):
        fake_B = self.g_AB.predict(img_A)      # Translate images to the other domain
        reconstr_A = self.g_BA.predict(fake_B)  # Translate back to original domain
        id_A = self.g_BA.predict(img_A)     # identity mapping (should leave img_A unchanged)
        return np.concatenate([img_A, fake_B, reconstr_A, id_A])


    def generate_image_from_B(self, img_B):
        fake_A = self.g_BA.predict(img_B)
        reconstr_B = self.g_AB.predict(fake_A)
        id_B = self.g_AB.predict(img_B)
        return np.concatenate([img_B, fake_A, reconstr_B, id_B])


    @staticmethod
    def showImages(imgs, trans, recon, idimg, w=2.8, h=2.8, filepath=None):
        N = len(imgs)
        M = len(imgs[0])
        titles = ['Original', 'Translated', 'Reconstructed', 'ID']

        fig, ax = plt.subplots(N, M, figsize=(w*M, h*N))
        for i in range(N):
            for j in range(M):
                ax[i][j].imshow(imgs[i][j])
                ax[i][j].set_title(titles[j])
                ax[i][j].axis('off')

        if filepath is not None:
            dpath, fname = os.path.split(filepath)
            if dpath != '' and not os.path.exists(dpath):
                os.makedirs(dpath)
            fig.savefig(filepath, dpi=600)
            plt.close()
        else:
            plt.show()
        

    def showLoss(self, xlim=[], ylim=[]):
        print('loss AB')
        self.showLossAB(xlim, ylim)
        print('loss BA')
        self.showLossBA(xlim, ylim)


    def showLossAB(self, xlim=[], ylim=[]):
        g = np.array(self.g_losses)
        g_loss = g[:, 0]
        g_adv = g[:, 1]
        g_recon = g[:, 3]
        g_id = g[:, 5]
        CycleGAN.plot_history(
            [g_loss, g_adv, g_recon, g_id],
            ['g_loss', 'AB discrim', 'AB cycle', 'AB id'],
            xlim,
            ylim)

    def showLossBA(self, xlim=[], ylim=[]):
        g = np.array(self.g_losses)
        g_loss = g[:, 0]
        g_adv = g[:, 2]
        g_recon = g[:, 4]
        g_id = g[:, 6]
        CycleGAN.plot_history(
            [g_loss, g_adv, g_recon, g_id],
            ['g_loss', 'BA discrim', 'BA cycle', 'BA id'],
            xlim,
            ylim)


    @staticmethod
    def plot_history(vals, labels, xlim=[], ylim=[]):
        colors = ['red', 'blue', 'green', 'orange', 'black', 'pink']
        n = len(vals)
        fig, ax = plt.subplots(1, 1, figsize=(12,6))
        for i in range(n):
            ax.plot(vals[i], c=colors[i], label=labels[i])
        ax.legend(loc='upper right')
        ax.set_xlabel('epochs')
        # ax.set_ylabel('loss')

        if xlim != []:
            ax.set_xlim(xlim[0], xlim[1])
        if ylim != []:
            ax.set_ylim(ylim[0], ylim[1])
        
        plt.show()
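
As a quick sanity check on the PatchGAN geometry defined above: the discriminator halves the spatial resolution three times (strides 2, 2, 2, 1), so a 128x128 input yields a 16x16x1 validity map, which is what patch = int(self.img_rows / 2**3) computes. A minimal standalone sketch (NumPy only; IMG_ROWS and the printed shapes are illustrative, not part of CycleGAN.py):

import numpy as np

IMG_ROWS = 128                      # the IMAGE_SIZE used later in this notebook
patch = IMG_ROWS // 2**3            # three stride-2 convolutions => factor 8
disc_patch = (patch, patch, 1)
print(disc_patch)                   # (16, 16, 1)

# Adversarial ground truths, built the same way as in CycleGAN.train()
batch_size = 1
valid = np.ones((batch_size,) + disc_patch)
fake = np.zeros((batch_size,) + disc_patch)
print(valid.shape, fake.shape)      # (1, 16, 16, 1) (1, 16, 16, 1)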

Preparing VidTIMIT dataset

Official website of the VidTIMIT dataset: http://conradsanderson.id.au/vidtimit/

Zip files for two subjects from the VidTIMIT dataset:
https://zenodo.org/record/158963/files/fadg0.zip
https://zenodo.org/record/158963/files/faks0.zip

Zip files mirrored on my Google Drive:
https://drive.google.com/uc?id=

In [10]:
# Download zip files
VidTIMIT_site = 'https://zenodo.org/record/158963/files/'
VidTIMIT_fnames = [ 'fadg0', 'faks0']

Mirrored_files = [
    'https://drive.google.com/uc?id=1_Fv4p9MDNphMZMnLpEvtCtnwXgN8N5Cj', 
    'https://drive.google.com/uc?id=1Y8j7ThPVqB0gbx4hb9aMEp9Ptr9wFuoz'
]

data_dir = './datasets'
! rm -rf $data_dir
! mkdir -p $data_dir

for i, fname in enumerate(VidTIMIT_fnames):
    fzip = fname + '.zip'
    if False:
        url = VidTIMIT_site + fzip
        !wget {url}
    else:
        url = Mirrored_files[i]
        !gdown {url}

    !unzip -q {fzip} -d {data_dir}
Downloading...
From: https://drive.google.com/uc?id=1_Fv4p9MDNphMZMnLpEvtCtnwXgN8N5Cj
To: /content/fadg0.zip
100% 81.6M/81.6M [00:00<00:00, 219MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Y8j7ThPVqB0gbx4hb9aMEp9Ptr9wFuoz
To: /content/faks0.zip
100% 64.2M/64.2M [00:00<00:00, 174MB/s]
In [11]:
! ls {data_dir}
fadg0  faks0
In [12]:
! ls {data_dir}/fadg0
audio  video
In [13]:
! ls {data_dir}/fadg0/video
head   head3  sa2     si1909  sx109  sx199  sx379
head2  sa1    si1279  si649   sx19   sx289
In [14]:
! ls {data_dir}/fadg0/video/head | head -5
001
002
003
004
005
In [15]:
! ls {data_dir}/fadg0/video/sx19 | head -5
001
002
003
004
005

Check the images of VidTIMIT

In [16]:
# Display images
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

def showImages(imgs, rows=-1, cols=-1, w=2, h=2):
    N = len(imgs)
    if rows < 0: rows = 1
    if cols < 0: cols = (N + rows -1) // rows
    fig, ax = plt.subplots(rows, cols, figsize=(w*cols, h*rows))
    idx = 0
    for row in range(rows):
        for col in range(cols) :
            if rows == 1 and cols == 1:
                axis = ax
            elif rows == 1:
                axis = ax[col]
            elif cols == 1:
                axis = ax[row]
            else:
                axis = ax[row][col]

            if idx < N:
                axis.imshow(imgs[idx])
            axis.axis('off')
            idx += 1
    plt.show()
In [17]:
# Display images given an array of file paths.
%matplotlib inline
import tensorflow as tf

def showImagesByPath(fnames,rows=-1, cols=-1, w=2, h=2):
    imgs = [ tf.keras.utils.load_img(fname) for fname in fnames]
    showImages(imgs, rows, cols, w, h)
In [18]:
import os
import glob

! ls {data_dir}/{VidTIMIT_fnames[0]}/video
head   head3  sa2     si1909  sx109  sx199  sx379
head2  sa1    si1279  si649   sx19   sx289
In [19]:
! ls {data_dir}
fadg0  faks0
In [20]:
# Images of A
person = VidTIMIT_fnames[0]

print(f'===={person}====')
paths = glob.glob(os.path.join(f'{data_dir}/{person}/video', '*'))
for path in paths:
    dpath, dir = os.path.split(path)
    fpath = glob.glob(os.path.join(path, '*'))
    print(f'{dir}  {len(fpath)}')
    showImagesByPath(fpath[:5],-1,-1,1,1)
====fadg0====
sx289  126
sx19  97
si1279  72
si1909  117
sa1  119
sx109  135
sx199  144
head2  406
head3  742
si649  216
sx379  109
sa2  103
head  346
In [21]:
# Images of B
person = VidTIMIT_fnames[1]

print(f'===={person}====')
paths = glob.glob(os.path.join(f'{data_dir}/{person}/video', '*'))
for path in paths:
    dpath, dir = os.path.split(path)
    fpath = glob.glob(os.path.join(path, '*'))
    print(f'{dir}  {len(fpath)}')
    showImagesByPath(fpath[:5],-1,-1,1,1)
====faks0====
sx133  97
sa1  89
sx313  90
head2  358
si2203  81
head3  369
sx403  97
sa2  81
sx223  89
si1573  145
head  458
si943  98
sx43  86

Make a DataGenerator from the VidTIMIT image files

In [22]:
IMAGE_SIZE = 128
In [23]:
import os
import glob

imgA_paths = glob.glob(os.path.join(data_dir, VidTIMIT_fnames[0], 'video/*/[0-9]*'))
imgB_paths = glob.glob(os.path.join(data_dir, VidTIMIT_fnames[1], 'video/*/[0-9]*'))
In [24]:
import numpy as np

validation_split = 0.05

nA, nB = len(imgA_paths), len(imgB_paths)
splitA = int(nA * (1 - validation_split))
splitB = int(nB * (1 - validation_split))

np.random.shuffle(imgA_paths)
np.random.shuffle(imgB_paths)

train_imgA_paths = imgA_paths[:splitA]
test_imgA_paths = imgA_paths[splitA:]
train_imgB_paths = imgB_paths[:splitB]
test_imgB_paths = imgB_paths[splitB:]
In [25]:
# Image: [-1, 1] --> [0, 1]
def M1P1_ZeroP1(imgs):
    imgs = (imgs + 1) * 0.5
    return np.clip(imgs, 0, 1)

# Image: [0, 1] --> [-1, 1]
def ZeroP1_M1P1(imgs):
    return imgs * 2 - 1
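A quick round-trip check of these helpers (a standalone sketch, assuming the cell above has run):

x = np.array([-1.0, 0.0, 1.0])
print(M1P1_ZeroP1(x))                  # [0.  0.5 1. ]
print(ZeroP1_M1P1(M1P1_ZeroP1(x)))     # [-1.  0.  1.]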
In [26]:
from nw.CycleGAN import PairDataset

pair_flow = PairDataset(train_imgA_paths, train_imgB_paths, target_size=(IMAGE_SIZE, IMAGE_SIZE))
test_pair_flow = PairDataset(test_imgA_paths, test_imgB_paths, target_size=(IMAGE_SIZE, IMAGE_SIZE))
In [27]:
a, b = next(pair_flow)
print(a.shape, b.shape)
showImages([M1P1_ZeroP1(a), M1P1_ZeroP1(b)])
(128, 128, 3) (128, 128, 3)
In [28]:
pairs = pair_flow[10:15]
shape = pairs.shape
print(shape)

a = pairs.reshape(-1, *shape[2:])
print(a.shape)
showImages(M1P1_ZeroP1(a), 5, 2)
(5, 2, 128, 128, 3)
(10, 128, 128, 3)

Define the Neural Network Model

In [29]:
from nw.CycleGAN import CycleGAN

gan = CycleGAN(
    input_dim = (IMAGE_SIZE, IMAGE_SIZE, 3),
    learning_rate = 0.0002,
    buffer_max_length = 50,
    lambda_validation = 1,
    lambda_reconstr = 10,
    lambda_id = 2,
    generator_type = 'unet',
    gen_n_filters = 32,
    disc_n_filters = 32
)
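
The three lambda arguments weight the six outputs of the combined model compiled in CycleGAN.py. Informally, with the values above the generator objective is

    L_G = 1 * (adversarial MSE for A and B) + 10 * (cycle-reconstruction MAE) + 2 * (identity MAE)

so cycle consistency dominates the training signal, in line with the CycleGAN paper's choice of 10 for the cycle term.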
In [30]:
gan.d_A.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 conv2d (Conv2D)             (None, 64, 64, 32)        1568      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 64, 64, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 32, 32, 64)        32832     
                                                                 
 instance_normalization (Ins  (None, 32, 32, 64)       0         
 tanceNormalization)                                             
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 32, 32, 64)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 16, 16, 128)       131200    
                                                                 
 instance_normalization_1 (I  (None, 16, 16, 128)      0         
 nstanceNormalization)                                           
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 16, 16, 128)       0         
                                                                 
 conv2d_3 (Conv2D)           (None, 16, 16, 256)       524544    
                                                                 
 instance_normalization_2 (I  (None, 16, 16, 256)      0         
 nstanceNormalization)                                           
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 16, 16, 256)       0         
                                                                 
 conv2d_4 (Conv2D)           (None, 16, 16, 1)         4097      
                                                                 
=================================================================
Total params: 694,241
Trainable params: 694,241
Non-trainable params: 0
_________________________________________________________________
In [31]:
gan.d_B.summary()
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 conv2d_5 (Conv2D)           (None, 64, 64, 32)        1568      
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 64, 64, 32)        0         
                                                                 
 conv2d_6 (Conv2D)           (None, 32, 32, 64)        32832     
                                                                 
 instance_normalization_3 (I  (None, 32, 32, 64)       0         
 nstanceNormalization)                                           
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 32, 32, 64)        0         
                                                                 
 conv2d_7 (Conv2D)           (None, 16, 16, 128)       131200    
                                                                 
 instance_normalization_4 (I  (None, 16, 16, 128)      0         
 nstanceNormalization)                                           
                                                                 
 leaky_re_lu_6 (LeakyReLU)   (None, 16, 16, 128)       0         
                                                                 
 conv2d_8 (Conv2D)           (None, 16, 16, 256)       524544    
                                                                 
 instance_normalization_5 (I  (None, 16, 16, 256)      0         
 nstanceNormalization)                                           
                                                                 
 leaky_re_lu_7 (LeakyReLU)   (None, 16, 16, 256)       0         
                                                                 
 conv2d_9 (Conv2D)           (None, 16, 16, 1)         4097      
                                                                 
=================================================================
Total params: 694,241
Trainable params: 694,241
Non-trainable params: 0
_________________________________________________________________
In [32]:
gan.g_AB.summary()
Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_3 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_10 (Conv2D)             (None, 64, 64, 32)   1568        ['input_3[0][0]']                
                                                                                                  
 instance_normalization_6 (Inst  (None, 64, 64, 32)  0           ['conv2d_10[0][0]']              
 anceNormalization)                                                                               
                                                                                                  
 activation (Activation)        (None, 64, 64, 32)   0           ['instance_normalization_6[0][0]'
                                                                 ]                                
                                                                                                  
 conv2d_11 (Conv2D)             (None, 32, 32, 64)   32832       ['activation[0][0]']             
                                                                                                  
 instance_normalization_7 (Inst  (None, 32, 32, 64)  0           ['conv2d_11[0][0]']              
 anceNormalization)                                                                               
                                                                                                  
 activation_1 (Activation)      (None, 32, 32, 64)   0           ['instance_normalization_7[0][0]'
                                                                 ]                                
                                                                                                  
 conv2d_12 (Conv2D)             (None, 16, 16, 128)  131200      ['activation_1[0][0]']           
                                                                                                  
 instance_normalization_8 (Inst  (None, 16, 16, 128)  0          ['conv2d_12[0][0]']              
 anceNormalization)                                                                               
                                                                                                  
 activation_2 (Activation)      (None, 16, 16, 128)  0           ['instance_normalization_8[0][0]'
                                                                 ]                                
                                                                                                  
 conv2d_13 (Conv2D)             (None, 8, 8, 256)    524544      ['activation_2[0][0]']           
                                                                                                  
 instance_normalization_9 (Inst  (None, 8, 8, 256)   0           ['conv2d_13[0][0]']              
 anceNormalization)                                                                               
                                                                                                  
 activation_3 (Activation)      (None, 8, 8, 256)    0           ['instance_normalization_9[0][0]'
                                                                 ]                                
                                                                                                  
 up_sampling2d (UpSampling2D)   (None, 16, 16, 256)  0           ['activation_3[0][0]']           
                                                                                                  
 conv2d_14 (Conv2D)             (None, 16, 16, 128)  524416      ['up_sampling2d[0][0]']          
                                                                                                  
 instance_normalization_10 (Ins  (None, 16, 16, 128)  0          ['conv2d_14[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_4 (Activation)      (None, 16, 16, 128)  0           ['instance_normalization_10[0][0]
                                                                 ']                               
                                                                                                  
 concatenate (Concatenate)      (None, 16, 16, 256)  0           ['activation_4[0][0]',           
                                                                  'activation_2[0][0]']           
                                                                                                  
 up_sampling2d_1 (UpSampling2D)  (None, 32, 32, 256)  0          ['concatenate[0][0]']            
                                                                                                  
 conv2d_15 (Conv2D)             (None, 32, 32, 64)   262208      ['up_sampling2d_1[0][0]']        
                                                                                                  
 instance_normalization_11 (Ins  (None, 32, 32, 64)  0           ['conv2d_15[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_5 (Activation)      (None, 32, 32, 64)   0           ['instance_normalization_11[0][0]
                                                                 ']                               
                                                                                                  
 concatenate_1 (Concatenate)    (None, 32, 32, 128)  0           ['activation_5[0][0]',           
                                                                  'activation_1[0][0]']           
                                                                                                  
 up_sampling2d_2 (UpSampling2D)  (None, 64, 64, 128)  0          ['concatenate_1[0][0]']          
                                                                                                  
 conv2d_16 (Conv2D)             (None, 64, 64, 32)   65568       ['up_sampling2d_2[0][0]']        
                                                                                                  
 instance_normalization_12 (Ins  (None, 64, 64, 32)  0           ['conv2d_16[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_6 (Activation)      (None, 64, 64, 32)   0           ['instance_normalization_12[0][0]
                                                                 ']                               
                                                                                                  
 concatenate_2 (Concatenate)    (None, 64, 64, 64)   0           ['activation_6[0][0]',           
                                                                  'activation[0][0]']             
                                                                                                  
 up_sampling2d_3 (UpSampling2D)  (None, 128, 128, 64  0          ['concatenate_2[0][0]']          
                                )                                                                 
                                                                                                  
 conv2d_17 (Conv2D)             (None, 128, 128, 3)  3075        ['up_sampling2d_3[0][0]']        
                                                                                                  
==================================================================================================
Total params: 1,545,411
Trainable params: 1,545,411
Non-trainable params: 0
__________________________________________________________________________________________________
In [33]:
gan.g_BA.summary()
Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_4 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_18 (Conv2D)             (None, 64, 64, 32)   1568        ['input_4[0][0]']                
                                                                                                  
 instance_normalization_13 (Ins  (None, 64, 64, 32)  0           ['conv2d_18[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_7 (Activation)      (None, 64, 64, 32)   0           ['instance_normalization_13[0][0]
                                                                 ']                               
                                                                                                  
 conv2d_19 (Conv2D)             (None, 32, 32, 64)   32832       ['activation_7[0][0]']           
                                                                                                  
 instance_normalization_14 (Ins  (None, 32, 32, 64)  0           ['conv2d_19[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_8 (Activation)      (None, 32, 32, 64)   0           ['instance_normalization_14[0][0]
                                                                 ']                               
                                                                                                  
 conv2d_20 (Conv2D)             (None, 16, 16, 128)  131200      ['activation_8[0][0]']           
                                                                                                  
 instance_normalization_15 (Ins  (None, 16, 16, 128)  0          ['conv2d_20[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_9 (Activation)      (None, 16, 16, 128)  0           ['instance_normalization_15[0][0]
                                                                 ']                               
                                                                                                  
 conv2d_21 (Conv2D)             (None, 8, 8, 256)    524544      ['activation_9[0][0]']           
                                                                                                  
 instance_normalization_16 (Ins  (None, 8, 8, 256)   0           ['conv2d_21[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_10 (Activation)     (None, 8, 8, 256)    0           ['instance_normalization_16[0][0]
                                                                 ']                               
                                                                                                  
 up_sampling2d_4 (UpSampling2D)  (None, 16, 16, 256)  0          ['activation_10[0][0]']          
                                                                                                  
 conv2d_22 (Conv2D)             (None, 16, 16, 128)  524416      ['up_sampling2d_4[0][0]']        
                                                                                                  
 instance_normalization_17 (Ins  (None, 16, 16, 128)  0          ['conv2d_22[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_11 (Activation)     (None, 16, 16, 128)  0           ['instance_normalization_17[0][0]
                                                                 ']                               
                                                                                                  
 concatenate_3 (Concatenate)    (None, 16, 16, 256)  0           ['activation_11[0][0]',          
                                                                  'activation_9[0][0]']           
                                                                                                  
 up_sampling2d_5 (UpSampling2D)  (None, 32, 32, 256)  0          ['concatenate_3[0][0]']          
                                                                                                  
 conv2d_23 (Conv2D)             (None, 32, 32, 64)   262208      ['up_sampling2d_5[0][0]']        
                                                                                                  
 instance_normalization_18 (Ins  (None, 32, 32, 64)  0           ['conv2d_23[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_12 (Activation)     (None, 32, 32, 64)   0           ['instance_normalization_18[0][0]
                                                                 ']                               
                                                                                                  
 concatenate_4 (Concatenate)    (None, 32, 32, 128)  0           ['activation_12[0][0]',          
                                                                  'activation_8[0][0]']           
                                                                                                  
 up_sampling2d_6 (UpSampling2D)  (None, 64, 64, 128)  0          ['concatenate_4[0][0]']          
                                                                                                  
 conv2d_24 (Conv2D)             (None, 64, 64, 32)   65568       ['up_sampling2d_6[0][0]']        
                                                                                                  
 instance_normalization_19 (Ins  (None, 64, 64, 32)  0           ['conv2d_24[0][0]']              
 tanceNormalization)                                                                              
                                                                                                  
 activation_13 (Activation)     (None, 64, 64, 32)   0           ['instance_normalization_19[0][0]
                                                                 ']                               
                                                                                                  
 concatenate_5 (Concatenate)    (None, 64, 64, 64)   0           ['activation_13[0][0]',          
                                                                  'activation_7[0][0]']           
                                                                                                  
 up_sampling2d_7 (UpSampling2D)  (None, 128, 128, 64  0          ['concatenate_5[0][0]']          
                                )                                                                 
                                                                                                  
 conv2d_25 (Conv2D)             (None, 128, 128, 3)  3075        ['up_sampling2d_7[0][0]']        
                                                                                                  
==================================================================================================
Total params: 1,545,411
Trainable params: 1,545,411
Non-trainable params: 0
__________________________________________________________________________________________________
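Each decoder row in the summary above is one U-Net upsampling stage: UpSampling2D, a Conv2D (the parameter counts, e.g. 524416 = 128 x (256 x 4 x 4 + 1), imply a 4x4 kernel), an InstanceNormalization with zero trainable parameters (so centering and scaling must be disabled), an activation, and a skip Concatenate with the matching encoder output. A minimal sketch of such a stage, under those assumptions (upsample_block is our name, and 'relu' is a guess since the summary only shows generic Activation layers):

from tensorflow.keras.layers import UpSampling2D, Conv2D, Activation, Concatenate
from tensorflow_addons.layers import InstanceNormalization

def upsample_block(x, skip, filters):
    # One decoder stage: upsample, 4x4 conv, parameter-free instance norm,
    # activation, then skip concatenation with the encoder feature map.
    x = UpSampling2D(size=2)(x)
    x = Conv2D(filters, kernel_size=4, strides=1, padding='same')(x)
    # center=False, scale=False matches the 0 trainable params shown above
    x = InstanceNormalization(center=False, scale=False)(x)
    x = Activation('relu')(x)  # 'relu' is an assumption
    return Concatenate()([x, skip])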
In [34]:
gan.combined.summary()
Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_6 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 input_5 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 model_3 (Functional)           (None, 128, 128, 3)  1545411     ['input_6[0][0]',                
                                                                  'model_2[0][0]',                
                                                                  'input_5[0][0]']                
                                                                                                  
 model_2 (Functional)           (None, 128, 128, 3)  1545411     ['input_5[0][0]',                
                                                                  'model_3[0][0]',                
                                                                  'input_6[0][0]']                
                                                                                                  
 model (Functional)             (None, 16, 16, 1)    694241      ['model_3[0][0]']                
                                                                                                  
 model_1 (Functional)           (None, 16, 16, 1)    694241      ['model_2[0][0]']                
                                                                                                  
==================================================================================================
Total params: 4,479,304
Trainable params: 4,479,304
Non-trainable params: 0
__________________________________________________________________________________________________
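In the combined graph above, the two 1,545,411-parameter Functional models are the generators and the two models with 16x16x1 outputs are the PatchGAN discriminators; each generator is applied to the raw input, to the other generator's output (the cycle), and to its own domain (the identity term). A sketch of how such a composite is conventionally assembled; the loss weights below are illustrative only, not the notebook's actual values:

import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

def build_combined(g_AB, g_BA, d_A, d_B, img_shape=(128, 128, 3)):
    img_A = Input(shape=img_shape)
    img_B = Input(shape=img_shape)

    fake_B = g_AB(img_A)        # A -> B
    fake_A = g_BA(img_B)        # B -> A
    reconstr_A = g_BA(fake_B)   # A -> B -> A (cycle)
    reconstr_B = g_AB(fake_A)   # B -> A -> B (cycle)
    img_A_id = g_BA(img_A)      # identity terms
    img_B_id = g_AB(img_B)

    # The discriminators are frozen inside the combined model;
    # only the generators receive gradients here.
    d_A.trainable = False
    d_B.trainable = False
    valid_A = d_A(fake_A)       # 16x16x1 PatchGAN maps
    valid_B = d_B(fake_B)

    combined = Model([img_A, img_B],
                     [valid_A, valid_B, reconstr_A, reconstr_B,
                      img_A_id, img_B_id])
    combined.compile(
        loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
        loss_weights=[1, 1, 10, 10, 5, 5],  # illustrative weighting only
        optimizer=tf.keras.optimizers.Adam(2e-4, 0.5))
    return combined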

Train

In [35]:
save_path = '/content/drive/MyDrive/ColabRun/CycleGAN_VidTIMIT01'
In [36]:
! rm -rf {save_path}

Train for a few epochs

In [37]:
gan.train(
    pair_flow,
    epochs=1,
    batch_size=1,
    run_folder=save_path
)
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_train_function.<locals>.train_function at 0x7f417cbbd3b0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Epoch 1/1 100/2595 [D loss: 0.257 acc: 0.555][G loss: 2.595 adv: 0.902 recon: 0.138 id: 0.156 time: 0:00:29.597120
Epoch 1/1 200/2595 [D loss: 0.156 acc: 0.812][G loss: 2.454 adv: 0.485 recon: 0.163 id: 0.167 time: 0:00:48.062442
Epoch 1/1 300/2595 [D loss: 0.202 acc: 0.663][G loss: 2.156 adv: 0.726 recon: 0.118 id: 0.124 time: 0:01:06.676460
Epoch 1/1 400/2595 [D loss: 0.123 acc: 0.864][G loss: 1.798 adv: 0.463 recon: 0.112 id: 0.110 time: 0:01:25.102150
Epoch 1/1 500/2595 [D loss: 0.126 acc: 0.889][G loss: 1.966 adv: 0.569 recon: 0.115 id: 0.123 time: 0:01:43.598282
Epoch 1/1 600/2595 [D loss: 0.123 acc: 0.874][G loss: 2.198 adv: 0.911 recon: 0.109 id: 0.099 time: 0:02:02.081157
Epoch 1/1 700/2595 [D loss: 0.207 acc: 0.704][G loss: 2.535 adv: 0.853 recon: 0.142 id: 0.130 time: 0:02:20.621453
Epoch 1/1 800/2595 [D loss: 0.145 acc: 0.797][G loss: 2.127 adv: 0.800 recon: 0.113 id: 0.101 time: 0:02:38.969913
Epoch 1/1 900/2595 [D loss: 0.205 acc: 0.652][G loss: 1.895 adv: 0.474 recon: 0.120 id: 0.113 time: 0:02:57.377701
Epoch 1/1 1000/2595 [D loss: 0.180 acc: 0.731][G loss: 2.315 adv: 1.155 recon: 0.099 id: 0.085 time: 0:03:15.904667
Epoch 1/1 1100/2595 [D loss: 0.273 acc: 0.708][G loss: 1.403 adv: 0.320 recon: 0.092 id: 0.083 time: 0:03:34.858344
Epoch 1/1 1200/2595 [D loss: 0.304 acc: 0.553][G loss: 2.397 adv: 0.789 recon: 0.138 id: 0.115 time: 0:03:53.339480
Epoch 1/1 1300/2595 [D loss: 0.188 acc: 0.739][G loss: 2.784 adv: 1.287 recon: 0.127 id: 0.116 time: 0:04:11.640668
Epoch 1/1 1400/2595 [D loss: 0.081 acc: 0.972][G loss: 3.028 adv: 1.809 recon: 0.103 id: 0.095 time: 0:04:30.422425
Epoch 1/1 1500/2595 [D loss: 0.257 acc: 0.597][G loss: 1.916 adv: 0.826 recon: 0.092 id: 0.085 time: 0:04:49.425767
Epoch 1/1 1600/2595 [D loss: 0.173 acc: 0.737][G loss: 1.962 adv: 0.828 recon: 0.097 id: 0.084 time: 0:05:07.955474
Epoch 1/1 1700/2595 [D loss: 0.204 acc: 0.761][G loss: 1.540 adv: 0.392 recon: 0.098 id: 0.085 time: 0:05:26.161751
Epoch 1/1 1800/2595 [D loss: 0.207 acc: 0.646][G loss: 3.030 adv: 1.701 recon: 0.111 id: 0.108 time: 0:05:44.512333
Epoch 1/1 1900/2595 [D loss: 0.166 acc: 0.789][G loss: 1.769 adv: 0.498 recon: 0.108 id: 0.096 time: 0:06:03.126617
Epoch 1/1 2000/2595 [D loss: 0.229 acc: 0.592][G loss: 2.193 adv: 0.814 recon: 0.116 id: 0.108 time: 0:06:21.699297
Epoch 1/1 2100/2595 [D loss: 0.129 acc: 0.847][G loss: 1.766 adv: 0.460 recon: 0.112 id: 0.095 time: 0:06:40.152807
Epoch 1/1 2200/2595 [D loss: 0.131 acc: 0.831][G loss: 1.249 adv: 0.289 recon: 0.082 id: 0.069 time: 0:06:58.470594
Epoch 1/1 2300/2595 [D loss: 0.080 acc: 0.921][G loss: 2.083 adv: 0.825 recon: 0.106 id: 0.099 time: 0:07:17.008799
Epoch 1/1 2400/2595 [D loss: 0.142 acc: 0.816][G loss: 2.015 adv: 0.782 recon: 0.104 id: 0.095 time: 0:07:35.327724
Epoch 1/1 2500/2595 [D loss: 0.214 acc: 0.676][G loss: 2.067 adv: 0.967 recon: 0.094 id: 0.081 time: 0:07:53.849333
Epoch 1/1 [D loss: 0.166 acc: 0.779][G loss: 2.098 adv: 0.744 recon: 0.114 id: 0.107 time: 0:08:11.305807
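The tf.function retracing warning at the start of the run is harmless here beyond its startup cost; it appears because per-batch training calls can rebuild traces. If it mattered, the usual remedy, named in the warning itself, is to define the step function once, outside the batch loop. A generic sketch, not the notebook's actual CycleGAN.train implementation:

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam(2e-4)

# Defined once, outside the batch loop, so every call reuses the same trace.
@tf.function(experimental_relax_shapes=True)  # option named in the warning (TF 2.7)
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss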
In [38]:
! ls {save_path}/weights
combined-weights_1.h5  d_A-weights.h5	 g_AB-weights_1.h5  g_BA-weights.h5
combined-weights.h5    d_B-weights_1.h5  g_AB-weights.h5
d_A-weights_1.h5       d_B-weights.h5	 g_BA-weights_1.h5
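The naming convention: the unsuffixed *.h5 files always hold the latest weights, while the _1 suffix marks the snapshot written after epoch 1. To restore one snapshot by hand, a sketch (the gan.g_AB etc. attribute names are assumed to mirror the file names; CycleGAN.load, used further below, is the supported path):

epoch = 1  # which snapshot to restore
for name, net in [('g_AB', gan.g_AB), ('g_BA', gan.g_BA),
                  ('d_A', gan.d_A), ('d_B', gan.d_B)]:
    # attribute names assumed to match the saved file names above
    net.load_weights(f'{save_path}/weights/{name}-weights_{epoch}.h5')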
In [39]:
gan.train(
    pair_flow,
    epochs=3,
    batch_size=1,
    run_folder=save_path,
    print_step_interval=500
)
Epoch 2/3 500/2595 [D loss: 0.145 acc: 0.794][G loss: 2.187 adv: 0.990 recon: 0.102 id: 0.089 time: 0:01:33.490890
Epoch 2/3 1000/2595 [D loss: 0.046 acc: 0.961][G loss: 1.614 adv: 0.548 recon: 0.091 id: 0.077 time: 0:03:05.658150
Epoch 2/3 1500/2595 [D loss: 0.035 acc: 0.999][G loss: 1.228 adv: 0.295 recon: 0.079 id: 0.073 time: 0:04:38.242862
Epoch 2/3 2000/2595 [D loss: 0.086 acc: 0.928][G loss: 1.748 adv: 0.481 recon: 0.107 id: 0.099 time: 0:06:10.717727
Epoch 2/3 2500/2595 [D loss: 0.125 acc: 0.853][G loss: 1.760 adv: 0.658 recon: 0.094 id: 0.084 time: 0:07:43.549074
Epoch 2/3 [D loss: 0.112 acc: 0.869][G loss: 1.857 adv: 0.764 recon: 0.093 id: 0.084 time: 0:08:00.853371
Epoch 3/3 500/2595 [D loss: 0.091 acc: 0.905][G loss: 1.638 adv: 0.532 recon: 0.092 id: 0.092 time: 0:09:32.711431
Epoch 3/3 1000/2595 [D loss: 0.094 acc: 0.897][G loss: 2.009 adv: 1.012 recon: 0.084 id: 0.079 time: 0:11:05.072448
Epoch 3/3 1500/2595 [D loss: 0.117 acc: 0.829][G loss: 1.813 adv: 0.951 recon: 0.072 id: 0.072 time: 0:12:38.222414
Epoch 3/3 2000/2595 [D loss: 0.146 acc: 0.803][G loss: 2.120 adv: 0.872 recon: 0.105 id: 0.098 time: 0:14:10.363498
Epoch 3/3 2500/2595 [D loss: 0.031 acc: 1.000][G loss: 1.634 adv: 0.632 recon: 0.084 id: 0.081 time: 0:15:42.125561
Epoch 3/3 [D loss: 0.098 acc: 0.891][G loss: 1.870 adv: 0.837 recon: 0.087 id: 0.082 time: 0:15:59.501885

Generate Images

In [40]:
# Display the generated and cycle images.

test_pairs = test_pair_flow[:5]

test_imgsA = test_pairs[:,0]
test_imgsB = test_pairs[:,1]

imgsAB = gan.generate_image_from_A(test_imgsA)
imgsBA = gan.generate_image_from_B(test_imgsB)

print('A-->B-->A, ID')
showImages(M1P1_ZeroP1(imgsAB), 4)

print('B-->A-->B, ID')
showImages(M1P1_ZeroP1(imgsBA), 4)
A-->B-->A, ID
B-->A-->B, ID
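showImages and M1P1_ZeroP1 are helpers defined earlier in the notebook; judging by its name, M1P1_ZeroP1 rescales generator output from the tanh range [-1, +1] into [0, 1] for display. A one-line sketch under that assumption:

import numpy as np

def M1P1_ZeroP1(imgs):
    # Map tanh-range output [-1, +1] onto [0, 1] for display.
    return np.clip((imgs + 1.0) / 2.0, 0.0, 1.0)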

Check the loss and accuracy of the training process.

In [41]:
# Display the graph of losses in training
%matplotlib inline

gan.showLoss()
loss AB
loss BA
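showLoss draws one chart per translation direction (AB and BA) from the loss histories recorded during training. A generic sketch of such a plot, assuming the histories are plain Python lists of per-step scalars (the show_loss name and the list layout are our assumptions):

import matplotlib.pyplot as plt

def show_loss(history, title):
    # history: one scalar loss value per logged training step
    plt.figure(figsize=(8, 3))
    plt.plot(history)
    plt.title(title)
    plt.xlabel('step')
    plt.ylabel('loss')
    plt.show()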

Check the saved files

In [42]:
! ls -lR {save_path}
/content/drive/MyDrive/ColabRun/CycleGAN_VidTIMIT01:
total 7
-rw------- 1 root root  413 Nov 29 12:52 params_1.pkl
-rw------- 1 root root  895 Nov 29 13:08 params_3.pkl
-rw------- 1 root root  895 Nov 29 13:08 params.pkl
drwx------ 2 root root 4096 Nov 29 13:08 weights

/content/drive/MyDrive/ColabRun/CycleGAN_VidTIMIT01/weights:
total 105618
-rw------- 1 root root 17974272 Nov 29 12:52 combined-weights_1.h5
-rw------- 1 root root 17974272 Nov 29 13:08 combined-weights_3.h5
-rw------- 1 root root 17974272 Nov 29 13:08 combined-weights.h5
-rw------- 1 root root  2805136 Nov 29 12:52 d_A-weights_1.h5
-rw------- 1 root root  2805136 Nov 29 13:08 d_A-weights_3.h5
-rw------- 1 root root  2805136 Nov 29 13:08 d_A-weights.h5
-rw------- 1 root root  2805136 Nov 29 12:52 d_B-weights_1.h5
-rw------- 1 root root  2805136 Nov 29 13:08 d_B-weights_3.h5
-rw------- 1 root root  2805136 Nov 29 13:08 d_B-weights.h5
-rw------- 1 root root  6232880 Nov 29 12:52 g_AB-weights_1.h5
-rw------- 1 root root  6232880 Nov 29 13:08 g_AB-weights_3.h5
-rw------- 1 root root  6232880 Nov 29 13:08 g_AB-weights.h5
-rw------- 1 root root  6232880 Nov 29 12:52 g_BA-weights_1.h5
-rw------- 1 root root  6232880 Nov 29 13:08 g_BA-weights_3.h5
-rw------- 1 root root  6232880 Nov 29 13:08 g_BA-weights.h5
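params.pkl, and its per-snapshot copies params_1.pkl and params_3.pkl, presumably pickle the constructor arguments plus the current epoch counter so that CycleGAN.load can rebuild the model before reading the weights. A quick way to peek inside, assuming a plain pickle payload:

import pickle

with open(f'{save_path}/params.pkl', 'rb') as f:
    params = pickle.load(f)
print(params)  # assumed contents: constructor arguments plus the epoch counter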

Load the saved files and continue training

Load the saved parameters and model weights, then train further.

In [43]:
# Load the saved parameters and weights.

gan_work = CycleGAN.load(save_path)

# Display the number of training epochs completed so far.

print(gan_work.epoch)
3
In [44]:
# Continue training.

gan_work.train(
    pair_flow,
    epochs=5,
    batch_size=1,
    run_folder=save_path,
    print_step_interval=500,
    save_epoch_interval=5
)
WARNING:tensorflow:5 out of the last 38925 calls to <function Model.make_train_function.<locals>.train_function at 0x7f40e7dbfdd0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Epoch 4/5 500/2595 [D loss: 0.145 acc: 0.812][G loss: 1.774 adv: 0.681 recon: 0.091 id: 0.090 time: 0:01:42.066682
Epoch 4/5 1000/2595 [D loss: 0.043 acc: 1.000][G loss: 2.676 adv: 1.659 recon: 0.085 id: 0.084 time: 0:03:13.842908
Epoch 4/5 1500/2595 [D loss: 0.092 acc: 0.920][G loss: 1.699 adv: 0.822 recon: 0.073 id: 0.072 time: 0:04:46.132772
Epoch 4/5 2000/2595 [D loss: 0.074 acc: 0.953][G loss: 2.001 adv: 0.781 recon: 0.102 id: 0.100 time: 0:06:18.087056
Epoch 4/5 2500/2595 [D loss: 0.073 acc: 0.957][G loss: 1.870 adv: 0.907 recon: 0.080 id: 0.081 time: 0:07:50.475077
Epoch 4/5 [D loss: 0.088 acc: 0.911][G loss: 1.905 adv: 0.908 recon: 0.084 id: 0.081 time: 0:08:07.897072
Epoch 5/5 500/2595 [D loss: 0.078 acc: 0.943][G loss: 2.253 adv: 1.262 recon: 0.082 id: 0.085 time: 0:09:40.228815
Epoch 5/5 1000/2595 [D loss: 0.071 acc: 0.944][G loss: 2.606 adv: 1.626 recon: 0.082 id: 0.082 time: 0:11:12.493290
Epoch 5/5 1500/2595 [D loss: 0.056 acc: 0.993][G loss: 1.197 adv: 0.355 recon: 0.070 id: 0.071 time: 0:12:44.713004
Epoch 5/5 2000/2595 [D loss: 0.087 acc: 0.896][G loss: 2.062 adv: 0.906 recon: 0.096 id: 0.096 time: 0:14:16.900000
Epoch 5/5 2500/2595 [D loss: 0.110 acc: 0.865][G loss: 1.430 adv: 0.466 recon: 0.080 id: 0.080 time: 0:15:49.408150
Epoch 5/5 [D loss: 0.081 acc: 0.922][G loss: 1.952 adv: 0.978 recon: 0.082 id: 0.079 time: 0:16:06.877118
In [45]:
# Display the generated and cycle images.

test_pairs = test_pair_flow[:5]

test_imgsA = test_pairs[:,0]
test_imgsB = test_pairs[:,1]

imgsAB = gan_work.generate_image_from_A(test_imgsA)
imgsBA = gan_work.generate_image_from_B(test_imgsB)

print('A-->B-->A, ID')
showImages(M1P1_ZeroP1(imgsAB), 4)

print('B-->A-->B, ID')
showImages(M1P1_ZeroP1(imgsBA), 4)
A-->B-->A, ID
B-->A-->B, ID
In [46]:
# Display the graph of losses in training
%matplotlib inline

gan_work.showLoss()
loss AB
loss BA