Updated 29/Nov/2021 by Yoshihisa Nitta  

Cycle Generative Adversarial Network for VidTIMIT dataset with TensorFlow 2 on Google Colab

Train a Cycle Generative Adversarial Network (CycleGAN) on the VidTIMIT dataset.

In [1]:
#! pip install tensorflow==2.7.0
In [2]:
! pip install tensorflow_addons
Requirement already satisfied: tensorflow_addons in /usr/local/lib/python3.7/dist-packages (0.15.0)
Requirement already satisfied: typeguard>=2.7 in /usr/local/lib/python3.7/dist-packages (from tensorflow_addons) (2.7.1)
In [3]:
%tensorflow_version 2.x

import tensorflow as tf
print(tf.__version__)
2.7.0
In [4]:
import numpy as np

np.random.seed(2022)

Check the Google Colab runtime environment

In [5]:
! nvidia-smi
! cat /proc/cpuinfo
! cat /etc/issue
! free -h
Mon Nov 29 12:44:19 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 85
model name	: Intel(R) Xeon(R) CPU @ 2.00GHz
stepping	: 3
microcode	: 0x1
cpu MHz		: 2000.180
cache size	: 39424 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4000.36
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

processor	: 1
vendor_id	: GenuineIntel
cpu family	: 6
model		: 85
model name	: Intel(R) Xeon(R) CPU @ 2.00GHz
stepping	: 3
microcode	: 0x1
cpu MHz		: 2000.180
cache size	: 39424 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 1
initial apicid	: 1
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4000.36
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 bits virtual
power management:

Ubuntu 18.04.5 LTS \n \l

              total        used        free      shared  buff/cache   available
Mem:            12G        906M        2.0G        1.3M        9.8G         11G
Swap:            0B          0B          0B

Mount Google Drive from Google Colab

In [6]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [7]:
! ls /content/drive
MyDrive  Shareddrives

Download source file from Google Drive or nw.tsuda.ac.jp

Basically, use gdown to fetch the file from Google Drive. Download from nw.tsuda.ac.jp only if the Google Drive specifications change and the download from Google Drive no longer works.

In [8]:
# Download source file
nw_path = './nw'
! rm -rf {nw_path}
! mkdir -p {nw_path}

if True:   # from Google Drive
    url_model =  'https://drive.google.com/uc?id=1aNvpPDNeDWYQFu_PA1kOtFlzcO5seHky'
    ! (cd {nw_path}; gdown {url_model})
else:      # from nw.tsuda.ac.jp
    URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'
    url_model = f'{URL_NW}/models/CycleGAN.py'
    ! wget -nd {url_model} -P {nw_path}
Downloading...
From: https://drive.google.com/uc?id=1aNvpPDNeDWYQFu_PA1kOtFlzcO5seHky
To: /content/nw/CycleGAN.py
100% 24.6k/24.6k [00:00<00:00, 37.7MB/s]
In [9]:
! cat {nw_path}/CycleGAN.py
import tensorflow as tf
import tensorflow_addons as tf_addons
import numpy as np

import matplotlib.pyplot as plt

from collections import deque

import os
import pickle as pkl
import random
import datetime


################################################################################
# Data Loader
################################################################################
class PairDataset():
    def __init__(self, paths_A, paths_B, batch_size= 1, target_size = None, unaligned=False):
        self.paths_A = np.array(paths_A)
        self.paths_B = np.array(paths_B)
        self.target_size = target_size
        self.batch_size = batch_size
        self.unaligned = unaligned

        self.lenA = len(paths_A)
        self.lenB = len(paths_B)
        self.index = 0

    def __len__(self):
        return max(self.lenA, self.lenB)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # slice.indices() always returns concrete integers,
            # so start/stop/step can be passed straight to range()
            start, stop, step = index.indices(self.__len__())
            return np.array([self.__getitemInt__(i) for i in range(start, stop, step)])
        else:
            return self.__getitemInt__(index)

    def __getitemInt__(self, index):
        path_A = self.paths_A[index % self.lenA]
        if self.unaligned:
            path_B = self.paths_B[np.random.randint(self.lenB)]  # a single random index keeps path_B a scalar path
        else:
            path_B = self.paths_B[index % self.lenB]
        img_A = np.array(tf.keras.utils.load_img(path_A, target_size = self.target_size))
        img_B = np.array(tf.keras.utils.load_img(path_B, target_size = self.target_size))
        img_A = (img_A.astype('float32') - 127.5) / 127.5
        img_B = (img_B.astype('float32') - 127.5) / 127.5
        return np.array([img_A, img_B])

    def __next__(self):
        self.index += 1
        return self.__getitem__(self.index-1)



################################################################################
# Layer
################################################################################
class ReflectionPadding2D(tf.keras.layers.Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [tf.keras.layers.InputSpec(ndim=4)]
        super().__init__(**kwargs)

    def compute_output_shape(self, s):
        '''
        If you are using "channels_last" configuration
        '''
        return (s[0], s[1]+2*self.padding[0], s[2]+2*self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')


################################################################################
# Model
################################################################################
class CycleGAN():
    def __init__(
        self,
        input_dim,
        learning_rate,
        lambda_validation,
        lambda_reconstr,
        lambda_id,
        generator_type,
        gen_n_filters,
        disc_n_filters,
        buffer_max_length = 50,
        epoch = 0, 
        d_losses = [],
        g_losses = []
    ):
        self.input_dim = input_dim
        self.learning_rate = learning_rate
        self.buffer_max_length = buffer_max_length
        self.lambda_validation = lambda_validation
        self.lambda_reconstr = lambda_reconstr
        self.lambda_id = lambda_id
        self.generator_type = generator_type
        self.gen_n_filters = gen_n_filters
        self.disc_n_filters = disc_n_filters

        # Input shape
        self.img_rows = input_dim[0]
        self.img_cols = input_dim[1]
        self.channels = input_dim[2]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        
        self.epoch = epoch
        self.d_losses = d_losses
        self.g_losses = g_losses

        self.buffer_A = deque(maxlen=self.buffer_max_length)
        self.buffer_B = deque(maxlen=self.buffer_max_length)

        # Calculate output shape of D (PatchGAN)
        patch = int(self.img_rows / 2**3)
        self.disc_patch = (patch, patch, 1)
        
        self.weight_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        
        self.compile_models()


    def compile_models(self):
        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()

        self.d_A.compile(
            loss='mse',
            optimizer=tf.keras.optimizers.Adam(self.learning_rate, 0.5),
            metrics=['accuracy']
        )
        self.d_B.compile(
            loss='mse',
            optimizer=tf.keras.optimizers.Adam(self.learning_rate, 0.5),
            metrics=['accuracy']
        )

        # Build the generators
        if self.generator_type == 'unet':
            self.g_AB = self.build_generator_unet()
            self.g_BA = self.build_generator_unet()
        else:
            self.g_AB = self.build_generator_resnet()
            self.g_BA = self.build_generator_resnet()

        # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Input images from both domains
        img_A = tf.keras.layers.Input(shape=self.img_shape)
        img_B = tf.keras.layers.Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)

        # translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)

        # Identity mapping of images
        img_A_id = self.g_BA(img_A)   # identity loss: g_BA should leave a domain-A image unchanged
        img_B_id = self.g_AB(img_B)   # identity loss: g_AB should leave a domain-B image unchanged
        
        # Discriminators determine the validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators
        self.combined = tf.keras.models.Model(
            inputs=[img_A, img_B],
            outputs=[valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id]
        )
        self.combined.compile(
            loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],  # Mean Squared Error, Mean Absolute Error
            loss_weights=[ self.lambda_validation, self.lambda_validation,
                         self.lambda_reconstr, self.lambda_reconstr,
                         self.lambda_id, self.lambda_id ],
            optimizer=tf.keras.optimizers.Adam(0.0002, 0.5)   # note: fixed rate here; self.learning_rate is only used by the discriminators
        )
        self.d_A.trainable = True
        self.d_B.trainable = True


    def build_generator_unet(self):
        def downsample(layer_input, filters, f_size=4):
            d = tf.keras.layers.Conv2D(
                filters,
                kernel_size=f_size,
                strides=2,
                padding='same',
                kernel_initializer = self.weight_init  # [self memo] added by nitta
            )(layer_input)
            d = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(d)
            d = tf.keras.layers.Activation('relu')(d)
            return d
        def upsample(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            u = tf.keras.layers.UpSampling2D(size=2)(layer_input)
            u = tf.keras.layers.Conv2D(
                filters, 
                kernel_size=f_size, 
                strides=1, 
                padding='same',
                kernel_initializer = self.weight_init  # [self memo] added by nitta
            )(u)
            u = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(u)
            u = tf.keras.layers.Activation('relu')(u)
            if dropout_rate:
                u = tf.keras.layers.Dropout(dropout_rate)(u)
            u = tf.keras.layers.Concatenate()([u, skip_input])
            return u
        # Image input
        img = tf.keras.layers.Input(shape=self.img_shape)
        # Downsampling
        d1 = downsample(img, self.gen_n_filters)
        d2 = downsample(d1, self.gen_n_filters*2)
        d3 = downsample(d2, self.gen_n_filters*4)
        d4 = downsample(d3, self.gen_n_filters*8)

        # Upsampling
        u1 = upsample(d4, d3, self.gen_n_filters*4)
        u2 = upsample(u1, d2, self.gen_n_filters*2)
        u3 = upsample(u2, d1, self.gen_n_filters)

        u4 = tf.keras.layers.UpSampling2D(size=2)(u3)
        output_img = tf.keras.layers.Conv2D(
            self.channels, 
            kernel_size=4,
            strides=1, 
            padding='same',
            activation='tanh',
            kernel_initializer = self.weight_init  # [self memo] added by nitta
        )(u4)

        return tf.keras.models.Model(img, output_img)


    def build_generator_resnet(self):
        def conv7s1(layer_input, filters, final):
            y = ReflectionPadding2D(padding=(3,3))(layer_input)
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(7,7),
                strides=1,
                padding='valid',
                kernel_initializer=self.weight_init
            )(y)
            if final:
                y = tf.keras.layers.Activation('tanh')(y)
            else:
                y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
                y = tf.keras.layers.Activation('relu')(y)
            return y

        def downsample(layer_input, filters):
            y = tf.keras.layers.Conv2D(
                filters, 
                kernel_size=(3,3), 
                strides=2, 
                padding='same',
                kernel_initializer = self.weight_init
            )(layer_input)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.Activation('relu')(y)
            return y

        def residual(layer_input, filters):
            shortcut = layer_input
            y = ReflectionPadding2D(padding=(1,1))(layer_input)
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(3,3),
                strides=1,
                padding='valid',
                kernel_initializer=self.weight_init
            )(y)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.Activation('relu')(y)
            y = ReflectionPadding2D(padding=(1,1))(y)
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(3,3),
                strides=1,
                padding='valid',
                kernel_initializer=self.weight_init
            )(y)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            return tf.keras.layers.add([shortcut, y])
          
        def upsample(layer_input, filters):
            y = tf.keras.layers.Conv2DTranspose(
                filters, 
                kernel_size=(3,3), 
                strides=2,
                padding='same',
                kernel_initializer=self.weight_init
            )(layer_input)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.Activation('relu')(y)
            return y

        # Image input
        img = tf.keras.layers.Input(shape=self.img_shape)

        y = img
        y = conv7s1(y, self.gen_n_filters, False)
        y = downsample(y, self.gen_n_filters * 2)
        y = downsample(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = upsample(y, self.gen_n_filters * 2)
        y = upsample(y, self.gen_n_filters)
        y = conv7s1(y, 3, True)
        output = y
        
        return tf.keras.models.Model(img, output)


    def build_discriminator(self):
        def conv4(layer_input, filters, stride=2, norm=True):
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(4,4),
                strides=stride,
                padding='same',
                kernel_initializer = self.weight_init
              )(layer_input)
            if norm:
                y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.LeakyReLU(0.2)(y)
            return y

        img = tf.keras.layers.Input(shape=self.img_shape)
        y = conv4(img, self.disc_n_filters, stride=2, norm=False)
        y = conv4(y, self.disc_n_filters*2, stride=2)
        y = conv4(y, self.disc_n_filters*4, stride=2)
        y = conv4(y, self.disc_n_filters*8, stride=1)
        output = tf.keras.layers.Conv2D(
            1,
            kernel_size=4,
            strides=1,
            padding='same',
            kernel_initializer=self.weight_init
        )(y)
        return tf.keras.models.Model(img, output)


    def train_discriminators(self, imgs_A, imgs_B, valid, fake):
        # Translate images to opposite domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        
        self.buffer_B.append(fake_B)
        self.buffer_A.append(fake_A)

        fake_A_rnd = random.sample(self.buffer_A, min(len(self.buffer_A), len(imgs_A))) # random sampling without replacement 
        fake_B_rnd = random.sample(self.buffer_B, min(len(self.buffer_B), len(imgs_B)))
        
        # Train the discriminators (original images=real / translated = fake)
        dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
        dA_loss_fake = self.d_A.train_on_batch(fake_A_rnd, fake)
        dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

        dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
        dB_loss_fake = self.d_B.train_on_batch(fake_B_rnd, fake)
        dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

        # Total discriminator loss
        d_loss_total = 0.5 * np.add(dA_loss, dB_loss)

        # Return layout: [0] total loss, [1..6] component losses,
        #                [7] total accuracy, [8..13] component accuracies
        return (
            d_loss_total[0], 
            dA_loss[0], dA_loss_real[0], dA_loss_fake[0],
            dB_loss[0], dB_loss_real[0], dB_loss_fake[0],
            d_loss_total[1], 
            dA_loss[1], dA_loss_real[1], dA_loss_fake[1],
            dB_loss[1], dB_loss_real[1], dB_loss_fake[1]
        )


    def train_generators(self, imgs_A, imgs_B, valid):
        # Train the generators
        return self.combined.train_on_batch(
          [imgs_A, imgs_B], 
          [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B]
        )


    def train(self, data_loader, epochs, batch_size=1, run_folder='./run', print_step_interval=100, save_epoch_interval=100):
        start_time = datetime.datetime.now()
        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        steps = len(data_loader) // batch_size
        for epoch in range(self.epoch, epochs):
            step_d_losses = []
            step_g_losses = []
            for step in range(steps):
                start = step * batch_size
                end = start + batch_size
                pairs = data_loader[start:end]    # ((a,b), (a, b), ....)
                imgs_A, imgs_B = [], []
                for img_A, img_B in pairs:
                    imgs_A.append(img_A)
                    imgs_B.append(img_B)

                imgs_A = np.array(imgs_A)
                imgs_B = np.array(imgs_B)

                step_d_loss = self.train_discriminators(imgs_A, imgs_B, valid, fake)
                step_g_loss = self.train_generators(imgs_A, imgs_B, valid)

                step_d_losses.append(step_d_loss)
                step_g_losses.append(step_g_loss)

                elapsed_time = datetime.datetime.now() - start_time
                if (step+1) % print_step_interval == 0:
                    print(f'Epoch {epoch+1}/{epochs} {step+1}/{steps} [D loss: {step_d_loss[0]:.3f} acc: {step_d_loss[7]:.3f}][G loss: {step_g_loss[0]:.3f} adv: {np.sum(step_g_loss[1:3]):.3f} recon: {np.sum(step_g_loss[3:5]):.3f} id: {np.sum(step_g_loss[5:7]):.3f} time: {elapsed_time}]')

            d_loss = np.mean(step_d_losses, axis=0)
            g_loss = np.mean(step_g_losses, axis=0)

            elapsed_time = datetime.datetime.now() - start_time
            print(f'Epoch {epoch+1}/{epochs} [D loss: {d_loss[0]:.3f} acc: {d_loss[7]:.3f}][G loss: {g_loss[0]:.3f} adv: {np.sum(g_loss[1:3]):.3f} recon: {np.sum(g_loss[3:5]):.3f} id: {np.sum(g_loss[5:7]):.3f} time: {elapsed_time}]')
                    
            self.d_losses.append(d_loss)
            self.g_losses.append(g_loss)

            self.epoch += 1
            if (self.epoch) % save_epoch_interval == 0:
                self.save(run_folder, self.epoch)
                self.save(run_folder)

        self.save(run_folder, self.epoch)
        self.save(run_folder)


    def save(self, folder, epoch=None):
        self.save_params(folder, epoch)
        self.save_weights(folder,epoch)


    @staticmethod
    def load(folder, epoch=None):
        params = CycleGAN.load_params(folder, epoch)
        gan = CycleGAN(*params)
        gan.load_weights(folder, epoch)
        return gan


    def save_weights(self, run_folder, epoch=None):
        if epoch is None:
            self.save_model_weights(self.combined, os.path.join(run_folder, 'weights/combined-weights.h5'))
            self.save_model_weights(self.d_A, os.path.join(run_folder, 'weights/d_A-weights.h5'))
            self.save_model_weights(self.d_B, os.path.join(run_folder, 'weights/d_B-weights.h5'))
            self.save_model_weights(self.g_AB, os.path.join(run_folder, 'weights/g_AB-weights.h5'))
            self.save_model_weights(self.g_BA, os.path.join(run_folder, 'weights/g_BA-weights.h5'))
        else:
            self.save_model_weights(self.combined, os.path.join(run_folder, f'weights/combined-weights_{epoch}.h5'))
            self.save_model_weights(self.d_A, os.path.join(run_folder, f'weights/d_A-weights_{epoch}.h5'))
            self.save_model_weights(self.d_B, os.path.join(run_folder, f'weights/d_B-weights_{epoch}.h5'))
            self.save_model_weights(self.g_AB, os.path.join(run_folder, f'weights/g_AB-weights_{epoch}.h5'))
            self.save_model_weights(self.g_BA, os.path.join(run_folder, f'weights/g_BA-weights_{epoch}.h5'))


    def load_weights(self, run_folder, epoch=None):
        if epoch is None:
            self.load_model_weights(self.combined, os.path.join(run_folder, 'weights/combined-weights.h5'))
            self.load_model_weights(self.d_A, os.path.join(run_folder, 'weights/d_A-weights.h5'))
            self.load_model_weights(self.d_B, os.path.join(run_folder, 'weights/d_B-weights.h5'))
            self.load_model_weights(self.g_AB, os.path.join(run_folder, 'weights/g_AB-weights.h5'))
            self.load_model_weights(self.g_BA, os.path.join(run_folder, 'weights/g_BA-weights.h5'))
        else:
            self.load_model_weights(self.combined, os.path.join(run_folder, f'weights/combined-weights_{epoch}.h5'))
            self.load_model_weights(self.d_A, os.path.join(run_folder, f'weights/d_A-weights_{epoch}.h5'))
            self.load_model_weights(self.d_B, os.path.join(run_folder, f'weights/d_B-weights_{epoch}.h5'))
            self.load_model_weights(self.g_AB, os.path.join(run_folder, f'weights/g_AB-weights_{epoch}.h5'))
            self.load_model_weights(self.g_BA, os.path.join(run_folder, f'weights/g_BA-weights_{epoch}.h5'))


    def save_model_weights(self, model, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        model.save_weights(filepath)


    def load_model_weights(self, model, filepath):
        model.load_weights(filepath)


    def save_params(self, folder, epoch=None):
        if epoch is None:
            filepath = os.path.join(folder, 'params.pkl')
        else:
            filepath = os.path.join(folder, f'params_{epoch}.pkl')

        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)

        with open(filepath, 'wb') as f:
            pkl.dump([
                self.input_dim,
                self.learning_rate,
                self.lambda_validation,
                self.lambda_reconstr,
                self.lambda_id,
                self.generator_type,
                self.gen_n_filters,
                self.disc_n_filters,
                self.buffer_max_length,
                self.epoch,
                self.d_losses,
                self.g_losses
              ], f)


    @staticmethod
    def load_params(folder, epoch=None):
        if epoch is None:
            filepath = os.path.join(folder, 'params.pkl')
        else:
            filepath = os.path.join(folder, f'params_{epoch}.pkl')

        with open(filepath, 'rb') as f:
            params = pkl.load(f)
        return params


    def generate_image(self, img_A, img_B):
        gen_A = self.generate_image_from_A(img_A)
        gen_B = self.generate_image_from_B(img_B)
        return np.concatenate([gen_A, gen_B], axis=0)


    def generate_image_from_A(self, img_A):
        fake_B = self.g_AB.predict(img_A)      # Translate images to the other domain
        reconstr_A = self.g_BA.predict(fake_B)  # Translate back to original domain
        id_A = self.g_BA.predict(img_A)      # identity mapping
        return np.concatenate([img_A, fake_B, reconstr_A, id_A])


    def generate_image_from_B(self, img_B):
        fake_A = self.g_BA.predict(img_B)
        reconstr_B = self.g_AB.predict(fake_A)
        id_B = self.g_AB.predict(img_B)
        return np.concatenate([img_B, fake_A, reconstr_B, id_B])


    @staticmethod
    def showImages(imgs, trans, recon, idimg, w=2.8, h=2.8, filepath=None):
        N = len(imgs)
        M = len(imgs[0])
        titles = ['Original', 'Translated', 'Reconstructed', 'ID']

        fig, ax = plt.subplots(N, M, figsize=(w*M, h*N))
        for i in range(N):
            for j in range(M):
                ax[i][j].imshow(imgs[i][j])
                ax[i][j].set_title(titles[j])
                ax[i][j].axis('off')

        if filepath is not None:
            dpath, fname = os.path.split(filepath)
            if dpath != '' and not os.path.exists(dpath):
                os.makedirs(dpath)
            fig.savefig(filepath, dpi=600)
            plt.close()
        else:
            plt.show()
        

    def showLoss(self, xlim=[], ylim=[]):
        print('loss AB')
        self.showLossAB(xlim, ylim)
        print('loss BA')
        self.showLossBA(xlim, ylim)


    def showLossAB(self, xlim=[], ylim=[]):
        g = np.array(self.g_losses)
        g_loss = g[:, 0]
        g_adv = g[:, 1]
        g_recon = g[:, 3]
        g_id = g[:, 5]
        CycleGAN.plot_history(
            [g_loss, g_adv, g_recon, g_id],
            ['g_loss', 'AB discrim', 'AB cycle', 'AB id'],
            xlim,
            ylim)

    def showLossBA(self, xlim=[], ylim=[]):
        g = np.array(self.g_losses)
        g_loss = g[:, 0]
        g_adv = g[:, 2]
        g_recon = g[:, 4]
        g_id = g[:, 6]
        CycleGAN.plot_history(
            [g_loss, g_adv, g_recon, g_id],
            ['g_loss', 'BA discrim', 'BA cycle', 'BA id'],
            xlim,
            ylim)


    @staticmethod
    def plot_history(vals, labels, xlim=[], ylim=[]):
        colors = ['red', 'blue', 'green', 'orange', 'black', 'pink']
        n = len(vals)
        fig, ax = plt.subplots(1, 1, figsize=(12,6))
        for i in range(n):
            ax.plot(vals[i], c=colors[i], label=labels[i])
        ax.legend(loc='upper right')
        ax.set_xlabel('epochs')
        # ax.set_ylabel('loss')

        if xlim != []:
            ax.set_xlim(xlim[0], xlim[1])
        if ylim != []:
            ax.set_ylim(ylim[0], ylim[1])
        
        plt.show()
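
For reference, the constructor arguments above correspond one-to-one to the list pickled by save_params(), which is what allows CycleGAN.load() to rebuild a model. A minimal construction sketch follows; the lambda weights and filter counts are illustrative assumptions, not values prescribed by this notebook.

from nw.CycleGAN import CycleGAN

# Build a CycleGAN for 128x128 RGB images (all hyperparameters here are illustrative).
gan = CycleGAN(
    input_dim=(128, 128, 3),
    learning_rate=0.0002,    # used by the discriminator optimizers
    lambda_validation=1,     # weight of the adversarial (mse) terms
    lambda_reconstr=10,      # weight of the cycle-consistency (mae) terms
    lambda_id=2,             # weight of the identity-mapping (mae) terms
    generator_type='unet',   # any other value selects the ResNet generator
    gen_n_filters=32,
    disc_n_filters=32
)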

Prepare the VidTIMIT dataset

Official website of the VidTIMIT dataset: http://conradsanderson.id.au/vidtimit/

Zip files for two subjects of the VidTIMIT dataset:
https://zenodo.org/record/158963/files/fadg0.zip
https://zenodo.org/record/158963/files/faks0.zip

Zip files mirrored on my Google Drive (the full URLs appear in the download cell below):
https://drive.google.com/uc?id=

In [10]:
# Download zip files
VidTIMIT_site = 'https://zenodo.org/record/158963/files/'
VidTIMIT_fnames = [ 'fadg0', 'faks0']

Mirrored_files = [
    'https://drive.google.com/uc?id=1_Fv4p9MDNphMZMnLpEvtCtnwXgN8N5Cj', 
    'https://drive.google.com/uc?id=1Y8j7ThPVqB0gbx4hb9aMEp9Ptr9wFuoz'
]

data_dir = './datasets'
! rm -rf $data_dir
! mkdir -p $data_dir

for i, fname in enumerate(VidTIMIT_fnames):
    fzip = fname + '.zip'
    if False:   # from the official site (zenodo.org)
        url = VidTIMIT_site + fzip
        !wget {url}
    else:       # from the Google Drive mirror
        url = Mirrored_files[i]
        !gdown {url}

    !unzip -q {fzip} -d {data_dir}
Downloading...
From: https://drive.google.com/uc?id=1_Fv4p9MDNphMZMnLpEvtCtnwXgN8N5Cj
To: /content/fadg0.zip
100% 81.6M/81.6M [00:00<00:00, 219MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Y8j7ThPVqB0gbx4hb9aMEp9Ptr9wFuoz
To: /content/faks0.zip
100% 64.2M/64.2M [00:00<00:00, 174MB/s]
In [11]:
! ls {data_dir}
fadg0  faks0
In [12]:
! ls {data_dir}/fadg0
audio  video
In [13]:
! ls {data_dir}/fadg0/video
head   head3  sa2     si1909  sx109  sx199  sx379
head2  sa1    si1279  si649   sx19   sx289
In [14]:
! ls {data_dir}/fadg0/video/head | head -5
001
002
003
004
005
In [15]:
! ls {data_dir}/fadg0/video/sx19 | head -5
001
002
003
004
005

Check the images of VidTIMIT

In [16]:
# Display images
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

def showImages(imgs, rows=-1, cols=-1, w=2, h=2):
    N = len(imgs)
    if rows < 0: rows = 1
    if cols < 0: cols = (N + rows -1) // rows
    fig, ax = plt.subplots(rows, cols, figsize=(w*cols, h*rows))
    idx = 0
    for row in range(rows):
        for col in range(cols) :
            if rows == 1 and cols == 1:
                axis = ax
            elif rows == 1:
                axis = ax[col]
            elif cols == 1:
                axis = ax[row]
            else:
                axis = ax[row][col]

            if idx < N:
                axis.imshow(imgs[idx])
            axis.axis('off')
            idx += 1
    plt.show()
In [17]:
# Display the images given an array of file paths.
%matplotlib inline
import tensorflow as tf

def showImagesByPath(fnames,rows=-1, cols=-1, w=2, h=2):
    imgs = [ tf.keras.utils.load_img(fname) for fname in fnames]
    showImages(imgs, rows, cols, w, h)
In [18]:
import os
import glob

! ls {data_dir}/{VidTIMIT_fnames[0]}/video
head   head3  sa2     si1909  sx109  sx199  sx379
head2  sa1    si1279  si649   sx19   sx289
In [19]:
! ls {data_dir}
fadg0  faks0
In [20]:
# Images of A
person = VidTIMIT_fnames[0]

print(f'===={person}====')
paths = glob.glob(os.path.join(f'{data_dir}/{person}/video', '*'))
for path in paths:
    dpath, dir = os.path.split(path)
    fpath = glob.glob(os.path.join(path, '*'))
    print(f'{dir}  {len(fpath)}')
    showImagesByPath(fpath[:5],-1,-1,1,1)
====fadg0====
sx289  126
sx19  97
si1279  72
si1909  117
sa1  119
sx109  135
sx199  144
head2  406
head3  742
si649  216
sx379  109
sa2  103
head  346
In [21]:
# Images of B
person = VidTIMIT_fnames[1]

print(f'===={person}====')
paths = glob.glob(os.path.join(f'{data_dir}/{person}/video', '*'))
for path in paths:
    dpath, dir = os.path.split(path)
    fpath = glob.glob(os.path.join(path, '*'))
    print(f'{dir}  {len(fpath)}')
    showImagesByPath(fpath[:5],-1,-1,1,1)
====faks0====
sx133  97
sa1  89
sx313  90
head2  358
si2203  81
head3  369
sx403  97
sa2  81
sx223  89
si1573  145
head  458
si943  98
sx43  86

Make a data generator from the VidTIMIT image files

In [22]:
IMAGE_SIZE = 128
In [23]:
import os
import glob

# Collect every numbered frame file under each utterance directory,
# e.g. ./datasets/fadg0/video/head/001
imgA_paths = glob.glob(os.path.join(data_dir, VidTIMIT_fnames[0], 'video/*/[0-9]*'))
imgB_paths = glob.glob(os.path.join(data_dir, VidTIMIT_fnames[1], 'video/*/[0-9]*'))
In [24]:
import numpy as np

validation_split = 0.05

nA, nB = len(imgA_paths), len(imgB_paths)
splitA = int(nA * (1 - validation_split))
splitB = int(nB * (1 - validation_split))

np.random.shuffle(imgA_paths)
np.random.shuffle(imgB_paths)

train_imgA_paths = imgA_paths[:splitA]
test_imgA_paths = imgA_paths[splitA:]
train_imgB_paths = imgB_paths[:splitB]
test_imgB_paths = imgB_paths[splitB:]
In [25]:
# Image: [-1, 1] --> [0, 1]
def M1P1_ZeroP1(imgs):
    imgs = (imgs + 1) * 0.5
    return np.clip(imgs, 0, 1)

# Image: [0, 1] --> [-1, 1]
def ZeroP1_M1P1(imgs):
    return imgs * 2 - 1
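
A quick sanity check that the two helpers are inverses on the generator's output range (a hypothetical throw-away example):

x = np.array([-1.0, -0.5, 0.0, 0.5, 1.0])
assert np.allclose(ZeroP1_M1P1(M1P1_ZeroP1(x)), x)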
In [26]:
from nw.CycleGAN import PairDataset

pair_flow = PairDataset(train_imgA_paths, train_imgB_paths, target_size=(IMAGE_SIZE, IMAGE_SIZE))
test_pair_flow = PairDataset(test_imgA_paths, test_imgB_paths, target_size=(IMAGE_SIZE, IMAGE_SIZE))
In [27]:
a, b = next(pair_flow)
print(a.shape, b.shape)
showImages([M1P1_ZeroP1(a), M1P1_ZeroP1(b)])
(128, 128, 3) (128, 128, 3)
In [28]:
pairs = pair_flow[10:15]
shape = pairs.shape
print(shape)

a = pairs.reshape(-1, *shape[2:])
print(a.shape)
showImages(M1P1_ZeroP1(a), 5, 2)
(5, 2, 128, 128, 3)
(10, 128, 128, 3)
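
With the data generators in place, training and inspection reduce to calls on the CycleGAN class listed earlier. The sketch below is a minimal outline, assuming illustrative hyperparameters and a short run; none of these values are prescribed by this notebook.

from nw.CycleGAN import CycleGAN

# Construction as sketched after the class listing; all values are illustrative.
gan = CycleGAN((IMAGE_SIZE, IMAGE_SIZE, 3), 0.0002, 1, 10, 2, 'unet', 32, 32)

# Train on the paired generator; weights and params are checkpointed under run_folder.
gan.train(pair_flow, epochs=10, batch_size=1, run_folder='./run',
          print_step_interval=100, save_epoch_interval=5)

# Translate one held-out pair in both directions: each row shows
# original / translated / reconstructed / identity for one domain.
img_A, img_B = test_pair_flow[0]
gen = gan.generate_image(img_A[np.newaxis], img_B[np.newaxis])
showImages(M1P1_ZeroP1(gen), rows=2, cols=4)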