Updated 29/Nov/2021 by Yoshihisa Nitta  

Further Training of a Cycle Generative Adversarial Network on the VidTIMIT Dataset with TensorFlow 2 on Google Colab (WGAN-GP)

This notebook assumes that CycleGAN_VidTIMIT_Train.ipynb has already been run, and it continues training the saved model.

In [ ]:
MAX_EPOCHS = 50     # Raise this value and re-run this notebook to keep training

save_path = '/content/drive/MyDrive/ColabRun/CycleGAN_VidTIMIT01'
In [ ]:
#! pip install tensorflow==2.7.0
In [ ]:
! pip install tensorflow_addons
In [ ]:
%tensorflow_version 2.x

import tensorflow as tf
In [ ]:
import numpy as np


Check the Google Colab runtime environment

In [ ]:
! nvidia-smi
! cat /proc/cpuinfo
! cat /etc/issue
! free -h
Mount Google Drive from Google Colab

In [ ]:
from google.colab import drive
drive.mount('/content/drive')
In [ ]:
! ls /content/drive

Download source file from Google Drive or nw.tsuda.ac.jp

By default, download the file from Google Drive with gdown. Use nw.tsuda.ac.jp only if a change in Google Drive's specifications makes the Google Drive download fail.

In [ ]:
# Download source file
nw_path = './nw'
! rm -rf {nw_path}
! mkdir -p {nw_path}

if True:   # from Google Drive
    url_model =  'https://drive.google.com/uc?id=1aNvpPDNeDWYQFu_PA1kOtFlzcO5seHky'
    ! (cd {nw_path}; gdown {url_model})
else:      # from nw.tsuda.ac.jp
    URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'
    url_model = f'{URL_NW}/models/CycleGAN.py'
    ! wget -nd {url_model} -P {nw_path}
In [ ]:
! cat {nw_path}/CycleGAN.py
In [ ]:
# Download zip files
VidTIMIT_site = 'https://zenodo.org/record/158963/files/'
VidTIMIT_fnames = [ 'fadg0', 'faks0']

Mirrored_files = [
    # Google Drive URLs of the mirrored zip files, in the same order as VidTIMIT_fnames
]

data_dir = './datasets'
! rm -rf $data_dir
! mkdir -p $data_dir

for i, fname in enumerate(VidTIMIT_fnames):
    fzip = fname + '.zip'
    if False:   # from the original VidTIMIT site
        url = VidTIMIT_site + fzip
        !wget {url}
    else:       # from the Google Drive mirror
        url = Mirrored_files[i]
        !gdown {url}

    !unzip -q {fzip} -d {data_dir}
Make a DataGenerator from the VidTIMIT images

In [ ]:
import os
import glob

imgA_paths = glob.glob(os.path.join(data_dir, VidTIMIT_fnames[0], 'video/*/[0-9]*'))
imgB_paths = glob.glob(os.path.join(data_dir, VidTIMIT_fnames[1], 'video/*/[0-9]*'))
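
The pattern `video/*/[0-9]*` selects every file whose name begins with a digit, in any subdirectory of a speaker's `video` folder. The sketch below checks this on a tiny temporary directory tree. The folder and file names in it are hypothetical stand-ins for the VidTIMIT layout, not actual dataset contents. It sorts the result only so the demo output is deterministic. The notebook cell above keeps glob's unsorted order.

```python
import glob
import os
import tempfile

def find_frames(data_dir, speaker):
    """Return sorted paths matching <data_dir>/<speaker>/video/*/[0-9]*."""
    return sorted(glob.glob(os.path.join(data_dir, speaker, 'video/*/[0-9]*')))

# Hypothetical miniature tree: two clip folders, numbered frame files, and one non-numeric file.
with tempfile.TemporaryDirectory() as root:
    for clip, names in {'sa1': ['001', '002'], 'head': ['001', 'notes.txt']}.items():
        clip_dir = os.path.join(root, 'fadg0', 'video', clip)
        os.makedirs(clip_dir)
        for name in names:
            open(os.path.join(clip_dir, name), 'w').close()

    # Paths relative to the temporary root, so they are easy to read.
    rel = [os.path.relpath(p, root) for p in find_frames(root, 'fadg0')]
    print(rel)
```

The file `notes.txt` is skipped because its name does not start with a digit.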
In [ ]:
import numpy as np

validation_split = 0.05

nA, nB = len(imgA_paths), len(imgB_paths)
splitA = int(nA * (1 - validation_split))
splitB = int(nB * (1 - validation_split))


train_imgA_paths = imgA_paths[:splitA]
test_imgA_paths = imgA_paths[splitA:]
train_imgB_paths = imgB_paths[:splitB]
test_imgB_paths = imgB_paths[splitB:]
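
The split keeps the first 95% of each path list for training and puts the rest in the test set. Because `int()` truncates, any rounding remainder goes to the test set. Here is a standalone version of the same arithmetic, using made-up list sizes:

```python
def split_paths(paths, validation_split=0.05):
    """Split a list into (train, test) the same way as the notebook cell."""
    split = int(len(paths) * (1 - validation_split))
    return paths[:split], paths[split:]

train_a, test_a = split_paths(list(range(1000)))
train_b, test_b = split_paths(list(range(1001)))
print(len(train_a), len(test_a), len(train_b), len(test_b))
```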
In [ ]:
# Image: [-1, 1] --> [0, 1]
def M1P1_ZeroP1(imgs):
    imgs = (imgs + 1) * 0.5
    return np.clip(imgs, 0, 1)

# Image: [0, 1] --> [-1, 1]
def ZeroP1_M1P1(imgs):
    return imgs * 2 - 1
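
The two helpers are inverses on the valid range. The clip in `M1P1_ZeroP1` matters only when generated pixel values fall outside [-1, 1]. Without it, matplotlib's `imshow` would clip float RGB data itself and emit a warning. A quick round-trip check:

```python
import numpy as np

def M1P1_ZeroP1(imgs):
    """Map pixel values from [-1, 1] to [0, 1], clipping anything outside."""
    imgs = (imgs + 1) * 0.5
    return np.clip(imgs, 0, 1)

def ZeroP1_M1P1(imgs):
    """Map pixel values from [0, 1] to [-1, 1]."""
    return imgs * 2 - 1

x = np.array([0.0, 0.25, 0.5, 1.0])
y = ZeroP1_M1P1(x)                               # [-1.0, -0.5, 0.0, 1.0]
back = M1P1_ZeroP1(y)                            # recovers x
overshoot = M1P1_ZeroP1(np.array([-1.2, 1.3]))   # out-of-range values are clipped to 0 and 1
```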
In [ ]:
from nw.CycleGAN import PairDataset

# IMAGE_SIZE is not defined in this notebook; set it to the image size used in CycleGAN_VidTIMIT_Train.ipynb
pair_flow = PairDataset(train_imgA_paths, train_imgB_paths, target_size=(IMAGE_SIZE, IMAGE_SIZE))
test_pair_flow = PairDataset(test_imgA_paths, test_imgB_paths, target_size=(IMAGE_SIZE, IMAGE_SIZE))
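
`PairDataset` is defined in the author's nw/CycleGAN.py, so check its exact behavior there. A later cell indexes it as `test_pair_flow[:5]` and then takes `[:, 0]` and `[:, 1]`. That suggests a slice returns an array in which each entry stacks one A image and one B image. The hypothetical `ToyPairDataset` below illustrates only that indexing convention, using in-memory arrays. Its name, its modulo pairing of B indices, and its output shape are assumptions, not the real class.

```python
import numpy as np

class ToyPairDataset:
    """Hypothetical stand-in for PairDataset: pairs A[i] with B[i % len(B)]."""

    def __init__(self, imgs_a, imgs_b):
        self.imgs_a = np.asarray(imgs_a)
        self.imgs_b = np.asarray(imgs_b)

    def __len__(self):
        return len(self.imgs_a)

    def _pair(self, i):
        i = range(len(self))[i]   # normalize negative indices
        # Stack one A image and one B image: shape (2, H, W, C)
        return np.stack([self.imgs_a[i], self.imgs_b[i % len(self.imgs_b)]])

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # A slice returns an array of pairs: shape (n, 2, H, W, C)
            return np.stack([self._pair(i) for i in range(len(self))[idx]])
        return self._pair(idx)

# Tiny fake "images": A image k is filled with the value k, every B image with -1.
imgs_a = np.zeros((8, 4, 4, 3)) + np.arange(8).reshape(8, 1, 1, 1)
imgs_b = -np.ones((6, 4, 4, 3))
ds = ToyPairDataset(imgs_a, imgs_b)

pairs = ds[:5]
test_imgsA = pairs[:, 0]   # the A half of each pair
test_imgsB = pairs[:, 1]   # the B half of each pair
print(pairs.shape, test_imgsA.shape)
```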

Load the Neural Network Model


In [ ]:
from nw.CycleGAN import CycleGAN

gan = CycleGAN.load(save_path)




Further Training


In [ ]:
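# Continue training the loaded model; check train() in nw/CycleGAN.py for its exact parameters
gan.train(
    pair_flow,
    epochs = MAX_EPOCHS,
    run_folder = save_path,
    print_step_interval = 1000,
    save_epoch_interval = 50
)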
In [ ]:
! ls {save_path}/weights
Generate Images


In [ ]:
# Display images
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

def showImages(imgs, rows=-1, cols=-1, w=2, h=2):
    N = len(imgs)
    if rows < 0: rows = 1
    if cols < 0: cols = (N + rows - 1) // rows
    fig, ax = plt.subplots(rows, cols, figsize=(w*cols, h*rows))
    idx = 0
    for row in range(rows):
        for col in range(cols):
            # Pick the Axes object; plt.subplots returns a scalar, 1-D or 2-D array
            if rows == 1 and cols == 1:
                axis = ax
            elif rows == 1:
                axis = ax[col]
            elif cols == 1:
                axis = ax[row]
            else:
                axis = ax[row][col]

            if idx < N:
                axis.imshow(imgs[idx])
            axis.axis('off')
            idx += 1
    plt.show()
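
`showImages` fills the grid in row-major order. When `rows` is given and `cols` is left at -1, it uses ceil(N / rows) columns. The helper below reproduces only that layout arithmetic, so it can be checked without drawing. For example, each `showImages(..., 4)` call in the next cell would use a 4 x 5 grid if it received 20 images. How many images the generate functions return is an assumption here.

```python
def grid_layout(n_images, rows=-1, cols=-1):
    """Return (rows, cols) and the (row, col) cell of each image, as showImages places them."""
    if rows < 0:
        rows = 1
    if cols < 0:
        cols = (n_images + rows - 1) // rows   # ceiling division
    # Images beyond rows * cols are not drawn, matching showImages.
    cells = [(idx // cols, idx % cols) for idx in range(min(n_images, rows * cols))]
    return (rows, cols), cells

shape, cells = grid_layout(20, 4)
print(shape, cells[5])   # the sixth image starts the second row
```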
In [ ]:
# Display generated and cycle images.

test_pairs = test_pair_flow[:5]

test_imgsA = test_pairs[:,0]
test_imgsB = test_pairs[:,1]

imgsAB = gan.generate_image_from_A(test_imgsA)
imgsBA = gan.generate_image_from_B(test_imgsB)

print('A-->B-->A, ID')
showImages(M1P1_ZeroP1(imgsAB), 4)

print('B-->A-->B, ID')
showImages(M1P1_ZeroP1(imgsBA), 4)
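
The labels A-->B-->A and ID refer to two checks. The first is the cycle reconstruction: translate an A image to B, then translate it back to A. The second is the identity mapping: give the B-to-A generator an image that is already in domain A, and likewise for B. CycleGAN compares both results with the input, commonly using a mean absolute (L1) error. The sketch below computes those two errors with hypothetical stand-in generators written as plain NumPy functions. It is not the loss code in nw/CycleGAN.py, whose terms and weights may differ.

```python
import numpy as np

def l1(x, y):
    """Mean absolute error between two image batches."""
    return float(np.mean(np.abs(x - y)))

# Hypothetical stand-in generators on [-1, 1] images (not trained networks).
def g_AB(imgs):
    return -imgs            # "translate" A to B by inverting brightness

def g_BA(imgs):
    return -0.9 * imgs      # imperfect inverse: inverts back and lowers contrast slightly

rng = np.random.default_rng(0)
imgs_A = rng.uniform(-1, 1, size=(5, 8, 8, 3))

reconstructed_A = g_BA(g_AB(imgs_A))   # A --> B --> A
identity_A = g_BA(imgs_A)              # an A image fed to the B-to-A generator

cycle_error = l1(reconstructed_A, imgs_A)
identity_error = l1(identity_A, imgs_A)
print(f'cycle L1: {cycle_error:.3f}, identity L1: {identity_error:.3f}')
```

A well-trained generator pair keeps both errors small. In this toy, the stand-in `g_BA` alters A images that it should leave unchanged, so the identity error is much larger.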

Check the losses recorded during training.


In [ ]:
# Display the graph of losses in training
%matplotlib inline

(Output: loss curves titled "loss AB" and "loss BA")
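
Here is a minimal sketch for drawing such curves, assuming the per-step losses are available as plain Python lists. The dictionary `history`, its keys, and its values are hypothetical, and the attribute names in nw/CycleGAN.py may differ. GAN losses fluctuate from step to step, so the sketch also overlays a simple moving average. That smoothing is an addition and is not part of the original notebook.

```python
import matplotlib
matplotlib.use('Agg')   # draw off-screen; omit this line inside a notebook
import matplotlib.pyplot as plt
import numpy as np

def moving_average(values, window=50):
    """Simple moving average; empty if there are fewer values than the window."""
    values = np.asarray(values, dtype=float)
    if len(values) < window:
        return np.array([])
    return np.convolve(values, np.ones(window) / window, mode='valid')

def plot_losses(history, window=50):
    """Draw one panel per loss name: raw values plus their moving average."""
    fig, axes = plt.subplots(1, len(history), figsize=(6 * len(history), 4), squeeze=False)
    for ax, (name, values) in zip(axes[0], history.items()):
        ax.plot(values, alpha=0.4, label='raw')
        smooth = moving_average(values, window)
        # Align each average with the last step of its window.
        ax.plot(np.arange(window - 1, window - 1 + len(smooth)), smooth, label=f'mean of {window}')
        ax.set_title(name)
        ax.set_xlabel('step')
        ax.legend()
    return fig

# Synthetic values, only to exercise the function.
rng = np.random.default_rng(0)
steps = np.arange(1, 301)
history = {'loss AB': list(1 / steps + rng.normal(0, 0.05, 300)),
           'loss BA': list(1 / steps + rng.normal(0, 0.05, 300))}
fig = plot_losses(history)
```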

Check the saved files


In [ ]:
! ls -lR {save_path}