5章 描く

スタイル変換 ... 入力となるベース画像を、与えられたスタイル画像と同じグループに属しているかのように変換する。スタイル画像からスタイルの要素だけを抽出し、それをベース画像に埋め込む。

この章では「CycleGAN とニューラルスタイル変換」の作成方法を学ぶ。

5.1 リンゴとオレンジ

5.2 CycleGAN

CycleGAN (cycle-consistent adversarial network)

Jun-Yan Zhu et al. "Unpaired Image-to-Image Translation Using Cycle-Consistent Adversarial Networks", 30/March/2017, https://arxiv.org/pdf/1703.10593

原論文は、ペアになった訓練画像のセットがなくても、参照する画像セットのスタイルを別の画像に転写できるようにモデルを訓練できることを示した。

5.3 初めての CycleGAN

この本の CycleGAN のコードは Erik Linder-Noren の Keras-GAN repository http://bit.ly/2Za68J2 を参考にしているとのこと。

https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/ から apple2orange.zip をダウンロードする。

G:\マイドライブ\DeepLearning\book14\local\GDL_code\data\apple2orange に展開した。

    GDL_code/data/apple2orange
      +--trainA/*.jpg    ... リンゴの画像
      +--testA/*.jpg
      +--trainB/*.jpg    ... オレンジの画像
      +--testB/*.jpg
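
[自習] ダウンロードと展開を Python で行う場合の最小のスケッチ（URL は上記ページにある apple2orange.zip、展開先はカレントディレクトリ直下の data/ を仮定している）。

    import os, zipfile, urllib.request

    url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip'
    os.makedirs('./data', exist_ok=True)
    zip_path = './data/apple2orange.zip'
    urllib.request.urlretrieve(url, zip_path)    # zip ファイルをダウンロードする
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall('./data')                  # zip 内の apple2orange/ ディレクトリが data/ の下に展開される想定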

5.3.1 概要

CycleGAN は、2つの Generator と2つの Discriminator、計4つのモデルから成り立っている。

  Generator
    g_AB ... 領域Aの画像を領域Bに変換する
    g_BA ... 領域Bの画像を領域Aに変換する

Generatorを訓練するためのペア画像は存在しないので、Generator の出力を判定する Discriminator を訓練する必要がある。

  Discriminator
    d_A ... 領域Aの本物画像と、Generator g_BA が生成した偽画像を識別するように訓練する
    d_B ... 領域Bの本物画像と、Generator g_AB が生成した偽画像を識別するように訓練する
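
4つのモデルの関係（変換・サイクル一貫性・識別）は、次のような最小のスケッチで確認できる（仮定: gan は後述の In [9] で構築する CycleGAN インスタンス、imgs_A と imgs_B は DataLoader で読み込んだ [-1,1] の画像バッチ）。

    fake_B     = gan.g_AB.predict(imgs_A)   # A -> B の変換（偽のオレンジ画像）
    fake_A     = gan.g_BA.predict(imgs_B)   # B -> A の変換（偽のリンゴ画像）
    reconstr_A = gan.g_BA.predict(fake_B)   # A -> B -> A。reconstr_A ≈ imgs_A を目指す（サイクル一貫性）
    reconstr_B = gan.g_AB.predict(fake_A)   # B -> A -> B。reconstr_B ≈ imgs_B を目指す（サイクル一貫性）
    validity_A = gan.d_A.predict(fake_A)    # d_A が g_BA の偽画像を判定する（PatchGAN 形式の出力）
    validity_B = gan.d_B.predict(fake_B)    # d_B が g_AB の偽画像を判定する（PatchGAN 形式の出力）
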
In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
#DATA_PATH = 'g:/マイドライブ/DeepLearning/book14/local/GDL_code/data/'
DATA_PATH = 'D:/Users/nitta/Documents/book14/local/GDL_code/data'
In [3]:
DATA_NAME = 'apple2orange'
IMAGE_SIZE = 128
In [4]:
# [自分へのメモ] 元のコードは、ipynb ファイルの直下の'./data/'からデータをロードするコードだが、
# path を指定できるように変更する。

import tensorflow as tf
import glob
import PIL
import imageio
import numpy as np

class DataLoader():
    def __init__(self, data_path, dataset_name, img_res=(256,256)): # data_path added
        self.data_path = data_path
        self.dataset_name = dataset_name
        self.img_res = img_res
        
    def load_data(self, domain, batch_size=1, is_testing=False):  # domain='A' or 'B'
        data_type = f'train{domain}' if not is_testing else f'test{domain}'
        path = glob.glob(f'{self.data_path}/{self.dataset_name}/{data_type}/*')
        
        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            img = np.array(PIL.Image.fromarray(img).resize(self.img_res))
            if not is_testing:
                if np.random.random() > 0.5:
                    img = np.fliplr(img)  # Reverse the order of elements along axis 1 (left/right)

            imgs.append(img)

        imgs = np.array(imgs)/127.5 - 1.0

        return imgs

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = 'train' if not is_testing else 'val'
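        # [注意] 元コードの命名のまま 'val' を使っているが、apple2orange には valA/valB ディレクトリが無い
        # （trainA/trainB/testA/testB のみ）ため、is_testing=True では glob が空リストを返す。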
        path_A = glob.glob(f'{self.data_path}/{self.dataset_name}/{data_type}A/*')
        path_B = glob.glob(f'{self.data_path}/{self.dataset_name}/{data_type}B/*')

        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size

        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        for i in range(self.n_batches):  # [self memo] ???BUG??? original source is self.n_batches-1
            batch_A = path_A[i*batch_size: (i+1)*batch_size]
            batch_B = path_B[i*batch_size: (i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)

                img_A = np.array(PIL.Image.fromarray(img_A).resize(self.img_res))
                img_B = np.array(PIL.Image.fromarray(img_B).resize(self.img_res))

                if not is_testing and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A)/127.5 - 1.0
            imgs_B = np.array(imgs_B)/127.5 - 1.0

            yield imgs_A, imgs_B

    def load_img(self, path):
        img = self.imread(path)
        img = np.array(PIL.Image.fromarray(img).resize(self.img_res))
        img = img / 127.5 - 1.0
        return img[np.newaxis, :, :, :]  # add new dimensions to ndarray (newaxis == None)

    def imread(self, path):
        return imageio.imread(path, pilmode='RGB').astype(np.uint8)
In [5]:
data_loader = DataLoader(data_path=DATA_PATH, dataset_name=DATA_NAME, img_res=(IMAGE_SIZE, IMAGE_SIZE))
In [6]:
# [自習] 正しくデータが読み込めていることを確認する。


data_gen = DataLoader(data_path=DATA_PATH, dataset_name=DATA_NAME, img_res=(IMAGE_SIZE, IMAGE_SIZE)).load_batch()
img_A, img_B = next(data_gen)

data_gen2 =  DataLoader(data_path=DATA_PATH, dataset_name=DATA_NAME, img_res=(IMAGE_SIZE, IMAGE_SIZE)).load_batch(
    #is_testing=True
)
img_A_test, img_B_test = next(data_gen2)

imgs = np.concatenate([img_A, img_B, img_A_test, img_B_test])
n = len(imgs)

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(1, n, figsize=(IMAGE_SIZE*n/10, IMAGE_SIZE/10))  # figsize は (幅, 高さ)

for i in range(n):
    ax[i].imshow(imgs[i])
    ax[i].axis('off')
    
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
（同じ警告が画像4枚分、計4回出力される。[-1,1] の画像をそのまま imshow に渡しているため、表示時に値がクリップされる。）

Model

元のソースコードでは CycleGAN.train() の中で頻繁に進捗状況を表示し、かつモデルを保存していたので、実行速度やメモリ使用量に非常に悪影響があった。これを epoch 毎に進捗状況やモデルを保存するように変更した。

In [7]:
# GDL_code/models/layers.py

import tensorflow as tf

class ReflectionPadding2D(tf.keras.layers.Layer):
    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [tf.keras.layers.InputSpec(ndim=4)]
        super().__init__(**kwargs)

    def compute_output_shape(self, s):
        '''
        If you are using "channels_last" configuration
        '''
        return (s[0], s[1]+2*self.padding[0], s[2]+2*self.padding[1], s[3])

    def call(self, x, mask=None):
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')
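
[自習] ReflectionPadding2D の動作を小さなテンソルで確認するスケッチ（上のセルで定義したクラスをそのまま使う想定）。

    x = tf.reshape(tf.range(9, dtype=tf.float32), (1, 3, 3, 1))   # 3x3・1チャネルの入力
    y = ReflectionPadding2D(padding=(1, 1))(x)                     # 境界を反射して上下左右に1ずつパディング
    print(y.shape)   # (1, 5, 5, 1)
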
In [8]:
# GDL_code/models/CycleGAN.py

import tensorflow as tf
import tensorflow_addons as tf_addons
import numpy as np
import os
import pickle as pkl
import random
import datetime

from collections import deque

import matplotlib.pyplot as plt   # sample_images() で plt を使うため追加

class CycleGAN():
    def __init__(
        self,
        input_dim,
        learning_rate,
        lambda_validation,
        lambda_reconstr,
        lambda_id,
        generator_type,
        gen_n_filters,
        disc_n_filters,
        buffer_max_length = 50
    ):
        self.input_dim = input_dim
        self.learning_rate = learning_rate
        self.buffer_max_length = buffer_max_length
        self.lambda_validation = lambda_validation
        self.lambda_reconstr = lambda_reconstr
        self.lambda_id = lambda_id
        self.generator_type = generator_type
        self.gen_n_filters = gen_n_filters
        self.disc_n_filters = disc_n_filters

        # Input shape
        self.img_rows = input_dim[0]
        self.img_cols = input_dim[1]
        self.channels = input_dim[2]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        
        self.d_losses = []
        self.g_losses = []
        self.epoch = 0

        self.buffer_A = deque(maxlen=self.buffer_max_length)
        self.buffer_B = deque(maxlen=self.buffer_max_length)

        # Calculate output shape of D (PatchGAN)
        patch = int(self.img_rows / 2**3)
        self.disc_patch = (patch, patch, 1)
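        # Discriminator は stride=2 の畳み込みを3回通すため、出力は入力の 1/2**3 (128 -> 16) のサイズになる。
        # 出力マップの各セルが入力画像の対応するパッチの真偽を判定する（PatchGAN）。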
        
        self.weight_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        
        self.compile_models()

    def compile_models(self):
        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()

        self.d_A.compile(
            loss='mse',
            optimizer=tf.keras.optimizers.Adam(self.learning_rate, 0.5),
            metrics=['accuracy']
        )
        self.d_B.compile(
            loss='mse',
            optimizer=tf.keras.optimizers.Adam(self.learning_rate, 0.5),
            metrics=['accuracy']
        )

        # Build the generators
        if self.generator_type == 'unet':
            self.g_AB = self.build_generator_unet()
            self.g_BA = self.build_generator_unet()
        else:
            self.g_AB = self.build_generator_resnet()
            self.g_BA = self.build_generator_resnet()

        # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Input images from both domains
        img_A = tf.keras.layers.Input(shape=self.img_shape)
        img_B = tf.keras.layers.Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)

        # translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)

        # Identity mapping of images
        img_A_id = self.g_BA(img_A)   # 同一性損失用: すでに領域Aの画像は g_BA を通しても img_A のまま変わらないことを期待する
        img_B_id = self.g_AB(img_B)   # 同一性損失用: すでに領域Bの画像は g_AB を通しても img_B のまま変わらないことを期待する
        
        # Discriminators determines validity of traslated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators
        self.combined = tf.keras.models.Model(
            inputs=[img_A, img_B],
            outputs=[valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id]
        )
        self.combined.compile(
            loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],  # Mean Squared Error, Mean Absolute Error
            loss_weights=[ self.lambda_validation, self.lambda_validation,
                         self.lambda_reconstr, self.lambda_reconstr,
                         self.lambda_id, self.lambda_id ],
            optimizer=tf.keras.optimizers.Adam(0.0002, 0.5)
        )
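        # 結合モデル全体の損失は次の加重和になる:
        #   lambda_validation * ( mse(valid_A) + mse(valid_B) )                      ... 敵対的損失
        # + lambda_reconstr   * ( mae(reconstr_A, img_A) + mae(reconstr_B, img_B) )  ... サイクル一貫性損失
        # + lambda_id         * ( mae(img_A_id, img_A) + mae(img_B_id, img_B) )      ... 同一性損失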
        self.d_A.trainable = True
        self.d_B.trainable = True

    def build_generator_unet(self):
        def downsample(layer_input, filters, f_size=4):
            d = tf.keras.layers.Conv2D(
                filters,
                kernel_size=f_size,
                strides=2,
                padding='same',
                kernel_initializer = self.weight_init  # [self memo] added by nitta
            )(layer_input)
            d = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(d)
            d = tf.keras.layers.Activation('relu')(d)
            return d
        def upsample(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            u = tf.keras.layers.UpSampling2D(size=2)(layer_input)
            u = tf.keras.layers.Conv2D(
                filters, 
                kernel_size=f_size, 
                strides=1, 
                padding='same',
                kernel_initializer = self.weight_init  # [self memo] added by nitta
            )(u)
            u = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(u)
            u = tf.keras.layers.Activation('relu')(u)
            if dropout_rate:
                u = tf.keras.layers.Dropout(dropout_rate)(u)
            u = tf.keras.layers.Concatenate()([u, skip_input])
            return u
        # Image input
        img = tf.keras.layers.Input(shape=self.img_shape)
        # Downsampling
        d1 = downsample(img, self.gen_n_filters)
        d2 = downsample(d1, self.gen_n_filters*2)
        d3 = downsample(d2, self.gen_n_filters*4)
        d4 = downsample(d3, self.gen_n_filters*8)

        # Upsampling
        u1 = upsample(d4, d3, self.gen_n_filters*4)
        u2 = upsample(u1, d2, self.gen_n_filters*2)
        u3 = upsample(u2, d1, self.gen_n_filters)

        u4 = tf.keras.layers.UpSampling2D(size=2)(u3)
        output_img = tf.keras.layers.Conv2D(
            self.channels, 
            kernel_size=4,
            strides=1, 
            padding='same',
            activation='tanh',
            kernel_initializer = self.weight_init  # [self memo] added by nitta
        )(u4)

        return tf.keras.models.Model(img, output_img)

    def build_generator_resnet(self):
        def conv7s1(layer_input, filters, final):
            y = ReflectionPadding2D(padding=(3,3))(layer_input)
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(7,7),
                strides=1,
                padding='valid',
                kernel_initializer=self.weight_init
            )(y)
            if final:
                y = tf.keras.layers.Activation('tanh')(y)
            else:
                y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
                y = tf.keras.layers.Activation('relu')(y)
            return y

        def downsample(layer_input, filters):
            y = tf.keras.layers.Conv2D(
                filters, 
                kernel_size=(3,3), 
                strides=2, 
                padding='same',
                kernel_initializer = self.weight_init
            )(layer_input)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.Activation('relu')(y)
            return y

        def residual(layer_input, filters):
            shortcut = layer_input
            y = ReflectionPadding2D(padding=(1,1))(layer_input)
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(3,3),
                strides=1,
                padding='valid',
                kernel_initializer=self.weight_init
            )(y)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.Activation('relu')(y)
            y = ReflectionPadding2D(padding=(1,1))(y)
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(3,3),
                strides=1,
                padding='valid',
                kernel_initializer=self.weight_init
            )(y)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            return tf.keras.layers.add([shortcut, y])
          
        def upsample(layer_input, filters):
            y = tf.keras.layers.Conv2DTranspose(
                filters, 
                kernel_size=(3,3), 
                strides=2,
                padding='same',
                kernel_initializer=self.weight_init
            )(layer_input)
            y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.Activation('relu')(y)
            return y

        # Image input
        img = tf.keras.layers.Input(shape=self.img_shape)

        y = img
        y = conv7s1(y, self.gen_n_filters, False)
        y = downsample(y, self.gen_n_filters * 2)
        y = downsample(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = upsample(y, self.gen_n_filters * 2)
        y = upsample(y, self.gen_n_filters)
        y = conv7s1(y, 3, True)
        output = y
        
        return tf.keras.models.Model(img, output)

    def build_discriminator(self):
        def conv4(layer_input, filters, stride=2, norm=True):
            y = tf.keras.layers.Conv2D(
                filters,
                kernel_size=(4,4),
                strides=stride,
                padding='same',
                kernel_initializer = self.weight_init
              )(layer_input)
            if norm:
                y = tf_addons.layers.InstanceNormalization(axis=-1, center=False, scale=False)(y)
            y = tf.keras.layers.LeakyReLU(0.2)(y)
            return y

        img = tf.keras.layers.Input(shape=self.img_shape)
        y = conv4(img, self.disc_n_filters, stride=2, norm=False)
        y = conv4(y, self.disc_n_filters*2, stride=2)
        y = conv4(y, self.disc_n_filters*4, stride=2)
        y = conv4(y, self.disc_n_filters*8, stride=1)
        output = tf.keras.layers.Conv2D(
            1,
            kernel_size=4,
            strides=1,
            padding='same',
            kernel_initializer=self.weight_init
        )(y)
        return tf.keras.models.Model(img, output)

    def train_discriminators(self, imgs_A, imgs_B, valid, fake):
        # Translate images to opposite domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        
        self.buffer_B.append(fake_B)
        self.buffer_A.append(fake_A)
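        # 生成した偽画像は直近 buffer_max_length 枚までバッファ（deque）に溜め、そこからサンプリングして
        # Discriminator の訓練に使う（過去の生成画像も混ぜて訓練を安定させる狙い。原論文の image pool に相当）。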

        fake_A_rnd = random.sample(self.buffer_A, min(len(self.buffer_A), len(imgs_A))) # random sampling without replacement 
        fake_B_rnd = random.sample(self.buffer_B, min(len(self.buffer_B), len(imgs_B)))
        
        # Train the discriminators (original images=real / translated = fake)
        dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
        dA_loss_fake = self.d_A.train_on_batch(fake_A_rnd, fake)
        dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

        dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
        dB_loss_fake = self.d_B.train_on_batch(fake_B_rnd, fake)
        dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

        # Total discriminator loss
        d_loss_total = 0.5 * np.add(dA_loss, dB_loss)

        return (
            d_loss_total[0], 
            dA_loss[0], dA_loss_real[0], dA_loss_fake[0],
            dB_loss[0], dB_loss_real[0], dB_loss_fake[0],
            d_loss_total[1], 
            dA_loss[1], dA_loss_real[1], dA_loss_fake[1],
            dB_loss[1], dB_loss_real[1], dB_loss_fake[1]
        )

    def train_generators(self, imgs_A, imgs_B, valid):
        # Train the generators
        return self.combined.train_on_batch(
          [imgs_A, imgs_B], 
          [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B]
        )

    def train(self, data_loader, run_folder, epochs, test_A_file, test_B_file, batch_size=1, sample_batch_interval=50, sample_epoch_interval=1):
        start_time = datetime.datetime.now()
        # Adversarial loss ground truthes
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(self.epoch, epochs):  # [self memo] ??? 2nd argument might be 'self.epoch+epochs'.
            for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch()):
                d_loss = self.train_discriminators(imgs_A, imgs_B, valid, fake)
                g_loss = self.train_generators(imgs_A, imgs_B, valid)

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                if (batch_i + 1) % sample_batch_interval == 0:  # [self memo] reduce displayed line
                    print(f'Epoch {self.epoch+1}/{epochs} Batch {batch_i+1}/{data_loader.n_batches} [D loss: {d_loss[0]:.3f} acc: {d_loss[7]:.3f}][G loss: {g_loss[0]:.3f} adv: {np.sum(g_loss[1:3]):.3f} recon: {np.sum(g_loss[3:5]):.3f} id: {np.sum(g_loss[5:7]):.3f} time: {elapsed_time:}')
                    self.sample_images(data_loader, (batch_i + 1), run_folder, test_A_file, test_B_file)
            ##########################
            # Start of change by nitta
            ##########################
            self.d_losses.append(d_loss)
            self.g_losses.append(g_loss)

            # if at save interval => save generated image samples
            if (self.epoch+1) % sample_epoch_interval == 0:
                self.sample_images(data_loader, batch_i, run_folder, test_A_file, test_B_file)
                self.combined.save_weights(os.path.join(run_folder, f'weights/weights-{self.epoch+1}.h5'))
                self.combined.save_weights(os.path.join(run_folder, 'weights/weights.h5'))
                self.save_model(run_folder)
            ##########################
            # End of change by nitta
            ##########################

            self.epoch += 1

    def sample_images(self, data_loader, batch_i, run_folder, test_A_file, test_B_file):
        r, c = 2, 4
        for p in range(2):
            if p == 1:
                imgs_A = data_loader.load_data(domain='A', batch_size=1, is_testing=True)
                imgs_B = data_loader.load_data(domain='B', batch_size=1, is_testing=True)
            else:
                imgs_A = data_loader.load_img(f'{data_loader.data_path}/{data_loader.dataset_name}/testA/{test_A_file}')
                imgs_B = data_loader.load_img(f'{data_loader.data_path}/{data_loader.dataset_name}/testB/{test_B_file}')
            # Translate images to the other domain
            fake_B = self.g_AB.predict(imgs_A)
            fake_A = self.g_BA.predict(imgs_B)
            # Translate back to original domain
            reconstr_A = self.g_BA.predict(fake_B)
            reconstr_B = self.g_AB.predict(fake_A)

            # ID the images
            id_A = self.g_BA.predict(imgs_A)
            id_B = self.g_AB.predict(imgs_B)

            gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, id_A, imgs_B, fake_A, reconstr_B, id_B])

            # Rescale images [-1, 1] --> [0, 1]
            gen_imgs = 0.5 * gen_imgs + 0.5
            gen_imgs = np.clip(gen_imgs, 0, 1)

            titles = ['Original', 'Translated', 'Reconstructed', 'ID']
            fig, ax = plt.subplots(r,c,figsize=(25, 12.5))
            cnt = 0
            for i in range(r):
                for j in range(c):
                    ax[i][j].imshow(gen_imgs[cnt])
                    ax[i][j].set_title(titles[j])
                    ax[i][j].axis('off')
                    cnt += 1
            fig.savefig(os.path.join(run_folder, f'images/{p}_{self.epoch}_{batch_i}.png'))
            plt.close()

    def plot_model(self, run_folder):
        tf.keras.utils.plot_model(
            self.d_A,
            to_file=os.path.join(run_folder, 'viz/d_A.png'),
            show_shapes = True,
            show_layer_names = True
        )
        tf.keras.utils.plot_model(
            self.d_B,
            to_file=os.path.join(run_folder, 'viz/d_B.png'),
            show_shapes = True,
            show_layer_names = True
        )
        tf.keras.utils.plot_model(
            self.g_BA,
            to_file=os.path.join(run_folder, 'viz/g_BA.png'),
            show_shapes = True,
            show_layer_names = True
        )
        tf.keras.utils.plot_model(
            self.g_AB,
            to_file=os.path.join(run_folder, 'viz/g_AB.png'),
            show_shapes = True,
            show_layer_names = True
        )

    def save(self, folder):
        with open(os.path.join(folder, 'params.pkl'), 'wb') as f:
            pkl.dump([
                self.input_dim,
                self.learning_rate,
                self.buffer_max_length,
                self.lambda_validation,
                self.lambda_reconstr,
                self.lambda_id,
                self.generator_type,
                self.gen_n_filters,
                self.disc_n_filters
              ], f)
        self.plot_model(folder)

    def save_model(self, run_folder):
        self.combined.save(os.path.join(run_folder, 'model.h5'))
        self.d_A.save(os.path.join(run_folder, 'd_A.h5'))
        self.d_B.save(os.path.join(run_folder, 'd_B.h5'))
        self.g_BA.save(os.path.join(run_folder, 'g_BA.h5'))
        self.g_AB.save(os.path.join(run_folder, 'g_AB.h5'))

    def load_weights(self, filepath):
        self.combined.load_weights(filepath)
        
In [9]:
gan = CycleGAN(
    input_dim = (IMAGE_SIZE, IMAGE_SIZE, 3),
    learning_rate = 0.0002,
    buffer_max_length = 50,
    lambda_validation = 1,
    lambda_reconstr = 10,
    lambda_id = 2,
    generator_type = 'unet',
    gen_n_filters = 32,
    disc_n_filters = 32
)
In [10]:
gan.g_BA.summary()
Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 64, 64, 32)   1568        input_4[0][0]                    
__________________________________________________________________________________________________
instance_normalization_13 (Inst (None, 64, 64, 32)   0           conv2d_18[0][0]                  
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 64, 64, 32)   0           instance_normalization_13[0][0]  
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 32, 32, 64)   32832       activation_7[0][0]               
__________________________________________________________________________________________________
instance_normalization_14 (Inst (None, 32, 32, 64)   0           conv2d_19[0][0]                  
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 32, 32, 64)   0           instance_normalization_14[0][0]  
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 16, 16, 128)  131200      activation_8[0][0]               
__________________________________________________________________________________________________
instance_normalization_15 (Inst (None, 16, 16, 128)  0           conv2d_20[0][0]                  
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 16, 16, 128)  0           instance_normalization_15[0][0]  
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 8, 8, 256)    524544      activation_9[0][0]               
__________________________________________________________________________________________________
instance_normalization_16 (Inst (None, 8, 8, 256)    0           conv2d_21[0][0]                  
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 8, 8, 256)    0           instance_normalization_16[0][0]  
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 16, 16, 256)  0           activation_10[0][0]              
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 16, 16, 128)  524416      up_sampling2d_4[0][0]            
__________________________________________________________________________________________________
instance_normalization_17 (Inst (None, 16, 16, 128)  0           conv2d_22[0][0]                  
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 16, 16, 128)  0           instance_normalization_17[0][0]  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 16, 16, 256)  0           activation_11[0][0]              
                                                                 activation_9[0][0]               
__________________________________________________________________________________________________
up_sampling2d_5 (UpSampling2D)  (None, 32, 32, 256)  0           concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 32, 32, 64)   262208      up_sampling2d_5[0][0]            
__________________________________________________________________________________________________
instance_normalization_18 (Inst (None, 32, 32, 64)   0           conv2d_23[0][0]                  
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 32, 32, 64)   0           instance_normalization_18[0][0]  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 32, 32, 128)  0           activation_12[0][0]              
                                                                 activation_8[0][0]               
__________________________________________________________________________________________________
up_sampling2d_6 (UpSampling2D)  (None, 64, 64, 128)  0           concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 64, 64, 32)   65568       up_sampling2d_6[0][0]            
__________________________________________________________________________________________________
instance_normalization_19 (Inst (None, 64, 64, 32)   0           conv2d_24[0][0]                  
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 64, 64, 32)   0           instance_normalization_19[0][0]  
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 64, 64, 64)   0           activation_13[0][0]              
                                                                 activation_7[0][0]               
__________________________________________________________________________________________________
up_sampling2d_7 (UpSampling2D)  (None, 128, 128, 64) 0           concatenate_5[0][0]              
__________________________________________________________________________________________________
conv2d_25 (Conv2D)              (None, 128, 128, 3)  3075        up_sampling2d_7[0][0]            
==================================================================================================
Total params: 1,545,411
Trainable params: 1,545,411
Non-trainable params: 0
__________________________________________________________________________________________________
In [11]:
gan.g_AB.summary()
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 64, 64, 32)   1568        input_3[0][0]                    
__________________________________________________________________________________________________
instance_normalization_6 (Insta (None, 64, 64, 32)   0           conv2d_10[0][0]                  
__________________________________________________________________________________________________
activation (Activation)         (None, 64, 64, 32)   0           instance_normalization_6[0][0]   
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 32, 32, 64)   32832       activation[0][0]                 
__________________________________________________________________________________________________
instance_normalization_7 (Insta (None, 32, 32, 64)   0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 32, 32, 64)   0           instance_normalization_7[0][0]   
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 16, 16, 128)  131200      activation_1[0][0]               
__________________________________________________________________________________________________
instance_normalization_8 (Insta (None, 16, 16, 128)  0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 16, 16, 128)  0           instance_normalization_8[0][0]   
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 8, 8, 256)    524544      activation_2[0][0]               
__________________________________________________________________________________________________
instance_normalization_9 (Insta (None, 8, 8, 256)    0           conv2d_13[0][0]                  
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 8, 8, 256)    0           instance_normalization_9[0][0]   
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 16, 16, 256)  0           activation_3[0][0]               
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 16, 16, 128)  524416      up_sampling2d[0][0]              
__________________________________________________________________________________________________
instance_normalization_10 (Inst (None, 16, 16, 128)  0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 16, 16, 128)  0           instance_normalization_10[0][0]  
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 16, 16, 256)  0           activation_4[0][0]               
                                                                 activation_2[0][0]               
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 32, 32, 256)  0           concatenate[0][0]                
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 32, 32, 64)   262208      up_sampling2d_1[0][0]            
__________________________________________________________________________________________________
instance_normalization_11 (Inst (None, 32, 32, 64)   0           conv2d_15[0][0]                  
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 32, 32, 64)   0           instance_normalization_11[0][0]  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 32, 32, 128)  0           activation_5[0][0]               
                                                                 activation_1[0][0]               
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 64, 64, 128)  0           concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 64, 64, 32)   65568       up_sampling2d_2[0][0]            
__________________________________________________________________________________________________
instance_normalization_12 (Inst (None, 64, 64, 32)   0           conv2d_16[0][0]                  
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 64, 64, 32)   0           instance_normalization_12[0][0]  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 64, 64, 64)   0           activation_6[0][0]               
                                                                 activation[0][0]                 
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 128, 128, 64) 0           concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 128, 128, 3)  3075        up_sampling2d_3[0][0]            
==================================================================================================
Total params: 1,545,411
Trainable params: 1,545,411
Non-trainable params: 0
__________________________________________________________________________________________________
In [12]:
gan.d_A.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 64, 64, 32)        1568      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        32832     
_________________________________________________________________
instance_normalization (Inst (None, 32, 32, 64)        0         
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 128)       131200    
_________________________________________________________________
instance_normalization_1 (In (None, 16, 16, 128)       0         
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 256)       524544    
_________________________________________________________________
instance_normalization_2 (In (None, 16, 16, 256)       0         
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 16, 16, 256)       0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 1)         4097      
=================================================================
Total params: 694,241
Trainable params: 694,241
Non-trainable params: 0
_________________________________________________________________
In [13]:
gan.d_B.summary()
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 64, 64, 32)        1568      
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 32, 32, 64)        32832     
_________________________________________________________________
instance_normalization_3 (In (None, 32, 32, 64)        0         
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 16, 16, 128)       131200    
_________________________________________________________________
instance_normalization_4 (In (None, 16, 16, 128)       0         
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 16, 16, 256)       524544    
_________________________________________________________________
instance_normalization_5 (In (None, 16, 16, 256)       0         
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 16, 16, 256)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 16, 16, 1)         4097      
=================================================================
Total params: 694,241
Trainable params: 694,241
Non-trainable params: 0
_________________________________________________________________

Train

In [14]:
import os

def md(p):
    if not os.path.isdir(p):
        os.makedirs(p)
In [15]:
SECTION = 'paint'
RUN_FOLDER = f'run/{SECTION}/{DATA_NAME}'

md(RUN_FOLDER)
md(f'{RUN_FOLDER}/viz')
md(f'{RUN_FOLDER}/images')
md(f'{RUN_FOLDER}/weights')
In [16]:
mode = 'build'
In [17]:
if mode == 'build':
    gan.save(RUN_FOLDER)
else:
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))
In [18]:
BATCH_SIZE=1
EPOCHS = 1 # 200
PRINT_EVERY_N_EPOCHS = 1

TEST_A_FILE = 'n07740461_14740.jpg'
TEST_B_FILE = 'n07749192_4241.jpg'
In [19]:
# [自習]
test_A_img = data_loader.load_img(f'{data_loader.data_path}/{data_loader.dataset_name}/testA/{TEST_A_FILE}')
test_B_img = data_loader.load_img(f'{data_loader.data_path}/{data_loader.dataset_name}/testB/{TEST_B_FILE}')

print(test_A_img.shape)
print(test_B_img.shape)
(1, 128, 128, 3)
(1, 128, 128, 3)
In [20]:
%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1,2,figsize=(6*2,6))

ax[0].imshow(np.clip((test_A_img[0] + 1) / 2, 0.0, 1.0))  # [-1,1] -> [0,1] に変換してから表示する
ax[0].axis('off')

ax[1].imshow(np.clip((test_B_img[0] + 1) / 2, 0.0, 1.0))  # [-1,1] -> [0,1] に変換してから表示する
ax[1].axis('off')

plt.show()
In [21]:
# [自習] [-1,1]の画像の配列を1行に表示する関数
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

def showImages(imgs):
    n = len(imgs)
    fig, ax = plt.subplots(1, n, figsize=(IMAGE_SIZE*n/20, IMAGE_SIZE/20))  # figsize は (幅, 高さ)
    for i in range(n):
        ax[i].imshow(np.clip((imgs[i]+1)/2, 0, 1))
        ax[i].axis('off')
    plt.show()
In [22]:
# [自習] テスト用の画像を処理してみる。
# 何も学習していない時点での gan.g_AB(), gan.g_BA() の出力を確認してみる。
test_imgs_A = data_loader.load_img(f'{data_loader.data_path}/{data_loader.dataset_name}/testA/{TEST_A_FILE}') # [-1, 1]
test_imgs_B = data_loader.load_img(f'{data_loader.data_path}/{data_loader.dataset_name}/testB/{TEST_B_FILE}') # [-1, 1]

fake_imgs_B = gan.g_AB(test_imgs_A)
fake_imgs_A = gan.g_BA(test_imgs_B)

showImages([fake_imgs_B[0], fake_imgs_A[0]])
In [23]:
# [自習] gan.train() のコードを変更した。
# gan.epoch < epochs である限り train() でのループを行うように変更した。
# sampling 周期も何 epoch 毎に行うかの指定に変更した。
import datetime
gan.train(
    data_loader,
    run_folder = RUN_FOLDER,
    epochs=1, ### [self memo] step by step
    test_A_file = TEST_A_FILE,
    test_B_file = TEST_B_FILE,
    batch_size = BATCH_SIZE,
    sample_batch_interval = 100,
    sample_epoch_interval = 1# [self memo] sampling once for each epoch
)
Epoch 1/1 Batch 100/995 [D loss: 0.379 acc: 0.417][G loss: 7.727 adv: 0.651 recon: 0.591 id: 0.583 time: 0:00:37.452171
Epoch 1/1 Batch 200/995 [D loss: 0.230 acc: 0.657][G loss: 8.067 adv: 0.527 recon: 0.630 id: 0.618 time: 0:01:03.235556
Epoch 1/1 Batch 300/995 [D loss: 0.219 acc: 0.649][G loss: 8.255 adv: 0.795 recon: 0.619 id: 0.633 time: 0:01:28.842880
Epoch 1/1 Batch 400/995 [D loss: 0.274 acc: 0.534][G loss: 6.428 adv: 0.783 recon: 0.469 id: 0.480 time: 0:01:55.014706
Epoch 1/1 Batch 500/995 [D loss: 0.265 acc: 0.554][G loss: 8.099 adv: 0.604 recon: 0.644 id: 0.525 time: 0:02:20.677731
Epoch 1/1 Batch 600/995 [D loss: 0.134 acc: 0.853][G loss: 5.759 adv: 0.966 recon: 0.399 id: 0.401 time: 0:02:46.608460
Epoch 1/1 Batch 700/995 [D loss: 0.199 acc: 0.655][G loss: 5.708 adv: 0.688 recon: 0.424 id: 0.390 time: 0:03:12.370254
Epoch 1/1 Batch 800/995 [D loss: 0.181 acc: 0.748][G loss: 6.604 adv: 0.933 recon: 0.478 id: 0.444 time: 0:03:37.907786
Epoch 1/1 Batch 900/995 [D loss: 0.162 acc: 0.795][G loss: 5.899 adv: 1.193 recon: 0.386 id: 0.422 time: 0:04:03.919202
In [24]:
print(gan.epoch)
1

[自分へのメモ]

gan.g_AB.predict(imgs_A) の結果が真っ黒になっていた。だが、これは DataLoader クラスのバグのためであり、解決した [2021/10/16追記]。

下の画像で、

  • 0行1列 ... gan.g_AB.predict(test_A_img)
  • 1行2列 ... gan.g_AB.predict(gan.g_BA.predict(test_B_img))
  • 1行3列 ... gan.g_AB.predict(imgs_B)

In [25]:
# [自習] gan.sample_image() が出力した画像を1つ表示してみる。
%matplotlib inline
import matplotlib.pyplot as plt
import imageio

sample_image = imageio.imread(f'{RUN_FOLDER}/images/0_0_994.png', pilmode='RGB')

fig, ax = plt.subplots(1,1,figsize=(25, 12.5))
ax.imshow(sample_image)
ax.axis('off')

plt.show()