Using a TPU on Google Colab (custom training loop)

https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/keras_flowers_customtrainloop_tf2.1.ipynb

This sample shows how to use the distribution strategy APIs when writing a custom training loop on TPU:

  • instantiate a TPUStrategy()
  • create the model and all other training objects in a strategy scope: with strategy.scope(): ...
  • distribute the dataset with strategy.experimental_distribute_dataset(ds)
  • run the training step distributed with strategy.run(step_fn)
  • aggregate results returned by distributed workers with strategy.reduce(...)
In [1]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import re, time
print("Tensorflow version " + tf.__version__)
Tensorflow version 2.7.0
In [2]:
try: # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError: # detect GPUs
    strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    #strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
    #strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() # for clusters of multi-GPU machines

print("Number of accelerators: ", strategy.num_replicas_in_sync)
INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.
INFO:tensorflow:Initializing the TPU system: grpc://10.58.238.58:8470
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
Number of accelerators:  8
In [3]:
EPOCHS = 60

if strategy.num_replicas_in_sync == 1: # GPU
    BATCH_SIZE = 16
    VALIDATION_BATCH_SIZE = 16
    START_LR = 0.01
    MAX_LR = 0.01
    MIN_LR = 0.01
    LR_RAMP = 0 # epochs
    LR_SUSTAIN = 0 #epochs
    LR_DECAY = 1
    
elif strategy.num_replicas_in_sync == 8: # single TPU
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync # use 32 on TPUv3
    VALIDATION_BATCH_SIZE = 256
    START_LR = 0.01
    MAX_LR = 0.01 * strategy.num_replicas_in_sync
    MIN_LR = 0.001
    LR_RAMP = 0 # epochs
    LR_SUSTAIN = 13 # epochs
    LR_DECAY = .95

else: # TPU pod
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync  # Global batch size.
    VALIDATION_BATCH_SIZE = 256
    START_LR = 0.06
    MAX_LR = 0.012 * strategy.num_replicas_in_sync
    MIN_LR = 0.01
    LR_RAMP = 5 # epochs
    LR_SUSTAIN = 8 # epochs
    LR_DECAY = 0.95

CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] # do not change, maps to the labels in the data (folder names)

IMAGE_SIZE = [331, 331] # supported image sizes: 192x192, 331x331, 512x512
                        # make sure you load the appropriate dataset on the next line
#GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-192x192-2/*.tfrec'
GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-331x331/*.tfrec'
#GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-512x512/*.tfrec'
VALIDATION_SPLIT = 0.19

def lrfn(epoch):
    if LR_RAMP > 0 and epoch < LR_RAMP:  # linear ramp from START_LR to MAX_LR
        lr = (MAX_LR - START_LR)/(LR_RAMP*1.0) * epoch + START_LR
    elif epoch < LR_RAMP + LR_SUSTAIN:  # constant at MAX_LR
        lr = MAX_LR
    else:  # exponential decay from MAX_LR to MIN_LR
        lr = (MAX_LR - MIN_LR) * LR_DECAY**(epoch-LR_RAMP-LR_SUSTAIN) + MIN_LR
    return lr
    
@tf.function
def lrfn_tffun(epoch):
    return lrfn(epoch)

print("Learning rate schedule:")
rng = [i for i in range(EPOCHS)]
plt.plot(rng, [lrfn(x) for x in rng])
plt.show()
Learning rate schedule:
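
As a worked check of the schedule plotted above, the sketch below re-evaluates lrfn in plain Python using the single-TPU constants from this cell. NUM_REPLICAS is an illustrative stand-in for strategy.num_replicas_in_sync (8 in the log output above); it is not a name used in the notebook.

```python
# Single-TPU settings copied from the configuration cell.
# NUM_REPLICAS is a hypothetical stand-in for strategy.num_replicas_in_sync.
NUM_REPLICAS = 8
START_LR = 0.01
MAX_LR = 0.01 * NUM_REPLICAS
MIN_LR = 0.001
LR_RAMP = 0      # epochs
LR_SUSTAIN = 13  # epochs
LR_DECAY = 0.95

def lrfn(epoch):
    if LR_RAMP > 0 and epoch < LR_RAMP:   # linear ramp from START_LR to MAX_LR
        return (MAX_LR - START_LR) / LR_RAMP * epoch + START_LR
    if epoch < LR_RAMP + LR_SUSTAIN:      # constant at MAX_LR
        return MAX_LR
    # exponential decay from MAX_LR towards MIN_LR
    return (MAX_LR - MIN_LR) * LR_DECAY ** (epoch - LR_RAMP - LR_SUSTAIN) + MIN_LR

for epoch in (0, 12, 13, 14, 59):
    print(epoch, round(lrfn(epoch), 5))
```

With these settings the rate stays at 0.08 through epoch 13 (the decay exponent is 0 there), drops to 0.07605 at epoch 14, and keeps shrinking towards MIN_LR without reaching it within 60 epochs.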
In [5]:
#@title display utilities [RUN ME]

def dataset_to_numpy_util(dataset, N):
  dataset = dataset.batch(N)
  
  if tf.executing_eagerly():
    # In eager mode, iterate over the Dataset directly.
    for images, labels in dataset:
      numpy_images = images.numpy()
      numpy_labels = labels.numpy()
      break

  else: # In non-eager (graph) mode, get the TF node that
        # yields the next item and run it in a session (TF1 compatibility API).
    get_next_item = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
    with tf.compat.v1.Session() as ses:
      numpy_images, numpy_labels = ses.run(get_next_item)

  return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
  label = np.argmax(label, axis=-1)  # one-hot to class number
  correct_label = np.argmax(correct_label, axis=-1) # one-hot to class number
  correct = (label == correct_label)
  return "{} [{}{}{}]".format(CLASSES[label], str(correct), ', should be ' if not correct else '',
                              CLASSES[correct_label] if not correct else ''), correct

def display_one_flower(image, title, subplot, red=False):
    plt.subplot(subplot)
    plt.axis('off')
    plt.imshow(image)
    plt.title(title, fontsize=16, color='red' if red else 'black')
    return subplot+1
  
def display_9_images_from_dataset(dataset):
  subplot=331
  plt.figure(figsize=(13,13))
  images, labels = dataset_to_numpy_util(dataset, 9)
  for i, image in enumerate(images):
    title = CLASSES[np.argmax(labels[i], axis=-1)]
    subplot = display_one_flower(image, title, subplot)
    if i >= 8:
      break
              
  #plt.tight_layout() # bug in tight layout in this version of matplotlib
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()
  
def display_9_images_with_predictions(images, predictions, labels):
  subplot=331
  plt.figure(figsize=(13,13))
  for i, image in enumerate(images):
    title, correct = title_from_label_and_target(predictions[i], labels[i])
    subplot = display_one_flower(image, title, subplot, not correct)
    if i >= 8:
      break
              
  #plt.tight_layout() # bug in tight layout in this version of matplotlib
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()
  
def display_training_curves(training, validation, title, subplot):
  if subplot%10==1: # set up the subplots on the first call
    plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
    #plt.tight_layout() # bug in tight layout in this version of matplotlib
  ax = plt.subplot(subplot)
  ax.set_facecolor('#F8F8F8')
  ax.plot(training)
  ax.plot(validation)
  ax.set_title('model '+ title)
  ax.set_ylabel(title)
  #ax.set_ylim(0.28,1.05)
  ax.set_xlabel('epoch')
  ax.legend(['train', 'valid.'])

Read images and labels from TFRecords

In [6]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

def count_data_items(filenames):
    # trick: the number of data items is written in the name of
    # the .tfrec files, i.e. flowers00-230.tfrec = 230 data items
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return int(np.sum(n))

def data_augment(image, one_hot_class):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_saturation(image, 0, 2)
    return image, one_hot_class

def read_tfrecord(example):
    features = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar
        "one_hot_class": tf.io.VarLenFeature(tf.float32),
    }
    example = tf.io.parse_single_example(example, features)
    image = tf.image.decode_jpeg(example['image'], channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # force the image size so that the shape of the tensor is known to Tensorflow
    class_label = tf.cast(example['class'], tf.int32)
    one_hot_class = tf.sparse.to_dense(example['one_hot_class'])
    one_hot_class = tf.reshape(one_hot_class, [5])
    return image, one_hot_class

def load_dataset(filenames):
    # read from TFRecords. For optimal performance, use TFRecordDataset with
    # num_parallel_reads to read from multiple TFRecord files at once
    # and set the option experimental_deterministic = False
    # to allow order-altering optimizations.

    opt = tf.data.Options()
    opt.experimental_deterministic = False

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=16) # can be AUTOTUNE in TF 2.1
    dataset = dataset.with_options(opt) # apply the options to the dataset that is actually returned
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset

def batch_dataset(filenames, batch_size, train):
    dataset = load_dataset(filenames)
    n = count_data_items(filenames)
    
    if train:
        dataset = dataset.repeat() # training dataset must repeat
        dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE)
        dataset = dataset.shuffle(2048)
    else:
        # usually fewer validation files than workers so disable FILE auto-sharding on validation
        if strategy.num_replicas_in_sync > 1: # option not useful if there is no sharding (not harmful either)
            opt = tf.data.Options()
            opt.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
            dataset = dataset.with_options(opt)
        # validation dataset does not need to repeat
        # also no need to shuffle or apply data augmentation
    if train:
        dataset = dataset.batch(batch_size)
    else:
        # little wrinkle: drop_remainder is NOT necessary but validation on the last
        # partial batch sometimes returns a "nan" loss (probably a bug). You can remove
        # this if you do not care about the validation loss.
        dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(AUTOTUNE) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset, n//batch_size

def get_training_dataset(filenames):
    dataset, steps = batch_dataset(filenames, BATCH_SIZE, train=True)
    return dataset, steps

def get_validation_dataset(filenames):
    dataset, steps = batch_dataset(filenames, VALIDATION_BATCH_SIZE, train=False)
    return dataset, steps
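
The file-name trick used by count_data_items can be tried without TensorFlow. The sketch below is a plain-Python version of the same regular expression; the bucket name and file names are made up for illustration and only follow the flowers00-230.tfrec pattern mentioned in the comment.

```python
import re

def count_data_items(filenames):
    # The item count sits between the last "-" and the "." of the extension.
    pattern = re.compile(r"-([0-9]*)\.")
    return sum(int(pattern.search(name).group(1)) for name in filenames)

# Hypothetical file names, for illustration only.
example_files = [
    "gs://example-bucket/flowers00-230.tfrec",
    "gs://example-bucket/flowers01-230.tfrec",
    "gs://example-bucket/flowers02-220.tfrec",
]
print(count_data_items(example_files))  # 680
```

Hyphens earlier in the path (here in the bucket name) are skipped because the pattern only matches a hyphen followed by digits and then a dot.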
In [7]:
# instantiate datasets
filenames = tf.io.gfile.glob(GCS_PATTERN)
split = len(filenames) - int(len(filenames) * VALIDATION_SPLIT)
train_filenames = filenames[:split]
valid_filenames = filenames[split:]

training_dataset, steps_per_epoch = get_training_dataset(train_filenames)
validation_dataset, validation_steps = get_validation_dataset(valid_filenames)

print("TRAINING   IMAGES: ", count_data_items(train_filenames), ", STEPS PER EPOCH: ", steps_per_epoch)
print("VALIDATION IMAGES: ", count_data_items(valid_filenames), ", STEPS PER EPOCH: ", validation_steps)

# numpy data to test predictions
some_flowers, some_labels = dataset_to_numpy_util(load_dataset(valid_filenames), 160)
TRAINING   IMAGES:  2990 , STEPS PER EPOCH:  23
VALIDATION IMAGES:  680 , STEPS PER EPOCH:  2
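
The step counts printed above follow from integer division of the image counts by the batch sizes. A minimal check, assuming the 8-replica single-TPU settings (NUM_REPLICAS is an illustrative name, not one from the notebook):

```python
NUM_REPLICAS = 8                      # stand-in for strategy.num_replicas_in_sync
BATCH_SIZE = 16 * NUM_REPLICAS        # global training batch size: 128
VALIDATION_BATCH_SIZE = 256
train_images, valid_images = 2990, 680  # counts from the cell output above

steps_per_epoch = train_images // BATCH_SIZE
validation_steps = valid_images // VALIDATION_BATCH_SIZE
# Images left over after the full validation batches.
dropped = valid_images - validation_steps * VALIDATION_BATCH_SIZE

print(steps_per_epoch, validation_steps, dropped)  # 23 2 168
```

Because the validation dataset is batched with drop_remainder=True, the 168 leftover images are never evaluated; a smaller validation batch size would reduce that loss.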
In [8]:
display_9_images_from_dataset(load_dataset(train_filenames))

The model: Squeezenet with 12 layers

In [9]:
def create_model():
    bnmomemtum=0.9 # with only a handful of batches per epoch, the batch norm running average period must be lowered
    def fire(x, squeeze, expand):
        y  = tf.keras.layers.Conv2D(filters=squeeze, kernel_size=1, activation=None, padding='same', use_bias=False)(x)
        y = tf.keras.layers.BatchNormalization(momentum=bnmomemtum, scale=False, center=True)(y)
        y = tf.keras.layers.Activation('relu')(y)
        y1 = tf.keras.layers.Conv2D(filters=expand//2, kernel_size=1, activation=None, padding='same', use_bias=False)(y)
        y1 = tf.keras.layers.BatchNormalization(momentum=bnmomemtum, scale=False, center=True)(y1)
        y1 = tf.keras.layers.Activation('relu')(y1)
        y3 = tf.keras.layers.Conv2D(filters=expand//2, kernel_size=3, activation=None, padding='same', use_bias=False)(y)
        y3 = tf.keras.layers.BatchNormalization(momentum=bnmomemtum, scale=False, center=True)(y3)
        y3 = tf.keras.layers.Activation('relu')(y3)
        return tf.keras.layers.concatenate([y1, y3])

    def fire_module(squeeze, expand):
        return lambda x: fire(x, squeeze, expand)

    x = tf.keras.layers.Input(shape=(*IMAGE_SIZE, 3)) # input is 331x331 pixels RGB
    y = tf.keras.layers.Conv2D(kernel_size=3, filters=32, padding='same', use_bias=True, activation='relu')(x)
    y = tf.keras.layers.BatchNormalization(momentum=bnmomemtum)(y)
    y = fire_module(24, 48)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(48, 96)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(64, 128)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(48, 96)(y)
    y = tf.keras.layers.MaxPooling2D(pool_size=2)(y)
    y = fire_module(24, 48)(y)
    y = tf.keras.layers.GlobalAveragePooling2D()(y)
    y = tf.keras.layers.Dropout(0.4)(y)
    y = tf.keras.layers.Dense(5, activation='softmax')(y)
    return tf.keras.Model(x, y)

# Custom learning rate schedule
class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __call__(self, step):
            return lrfn_tffun(epoch=step//steps_per_epoch)
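
LRSchedule receives the optimizer's step counter, not an epoch number, so it converts one to the other by integer division. The sketch below shows that mapping in plain Python, assuming the 23 steps per epoch printed earlier and the single-TPU values of LR_RAMP and LR_SUSTAIN; epoch_of is a hypothetical helper name.

```python
steps_per_epoch = 23          # from the dataset cell output
LR_RAMP, LR_SUSTAIN = 0, 13   # single-TPU settings

def epoch_of(step):
    # Same conversion as step // steps_per_epoch in LRSchedule.__call__
    return step // steps_per_epoch

# First optimizer step that falls into the decay branch of lrfn.
first_decay_step = (LR_RAMP + LR_SUSTAIN) * steps_per_epoch

print(epoch_of(22), epoch_of(23), first_decay_step, epoch_of(first_decay_step))
# 0 1 299 13
```

The learning rate is therefore piecewise constant within an epoch and only changes every 23 optimizer steps.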

Instantiate all objects in the strategy scope

In [10]:
with strategy.scope():
    model = create_model()
    
    # Instantiate optimizer with learning rate schedule
    optimizer = tf.keras.optimizers.SGD(nesterov=True, momentum=0.9, learning_rate=LRSchedule())
    train_accuracy = tf.keras.metrics.CategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.CategoricalAccuracy()
    loss_fn = lambda labels, probabilities: tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels, probabilities))
        
    model.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 331, 331, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 331, 331, 32  896         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 331, 331, 32  128        ['conv2d[0][0]']                 
 alization)                     )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 331, 331, 24  768         ['batch_normalization[0][0]']    
                                )                                                                 
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 331, 331, 24  72         ['conv2d_1[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 activation (Activation)        (None, 331, 331, 24  0           ['batch_normalization_1[0][0]']  
                                )                                                                 
                                                                                                  
 conv2d_2 (Conv2D)              (None, 331, 331, 24  576         ['activation[0][0]']             
                                )                                                                 
                                                                                                  
 conv2d_3 (Conv2D)              (None, 331, 331, 24  5184        ['activation[0][0]']             
                                )                                                                 
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 331, 331, 24  72         ['conv2d_2[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 batch_normalization_3 (BatchNo  (None, 331, 331, 24  72         ['conv2d_3[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 activation_1 (Activation)      (None, 331, 331, 24  0           ['batch_normalization_2[0][0]']  
                                )                                                                 
                                                                                                  
 activation_2 (Activation)      (None, 331, 331, 24  0           ['batch_normalization_3[0][0]']  
                                )                                                                 
                                                                                                  
 concatenate (Concatenate)      (None, 331, 331, 48  0           ['activation_1[0][0]',           
                                )                                 'activation_2[0][0]']           
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 165, 165, 48  0           ['concatenate[0][0]']            
                                )                                                                 
                                                                                                  
 conv2d_4 (Conv2D)              (None, 165, 165, 48  2304        ['max_pooling2d[0][0]']          
                                )                                                                 
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 165, 165, 48  144        ['conv2d_4[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 activation_3 (Activation)      (None, 165, 165, 48  0           ['batch_normalization_4[0][0]']  
                                )                                                                 
                                                                                                  
 conv2d_5 (Conv2D)              (None, 165, 165, 48  2304        ['activation_3[0][0]']           
                                )                                                                 
                                                                                                  
 conv2d_6 (Conv2D)              (None, 165, 165, 48  20736       ['activation_3[0][0]']           
                                )                                                                 
                                                                                                  
 batch_normalization_5 (BatchNo  (None, 165, 165, 48  144        ['conv2d_5[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 batch_normalization_6 (BatchNo  (None, 165, 165, 48  144        ['conv2d_6[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 activation_4 (Activation)      (None, 165, 165, 48  0           ['batch_normalization_5[0][0]']  
                                )                                                                 
                                                                                                  
 activation_5 (Activation)      (None, 165, 165, 48  0           ['batch_normalization_6[0][0]']  
                                )                                                                 
                                                                                                  
 concatenate_1 (Concatenate)    (None, 165, 165, 96  0           ['activation_4[0][0]',           
                                )                                 'activation_5[0][0]']           
                                                                                                  
 max_pooling2d_1 (MaxPooling2D)  (None, 82, 82, 96)  0           ['concatenate_1[0][0]']          
                                                                                                  
 conv2d_7 (Conv2D)              (None, 82, 82, 64)   6144        ['max_pooling2d_1[0][0]']        
                                                                                                  
 batch_normalization_7 (BatchNo  (None, 82, 82, 64)  192         ['conv2d_7[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_6 (Activation)      (None, 82, 82, 64)   0           ['batch_normalization_7[0][0]']  
                                                                                                  
 conv2d_8 (Conv2D)              (None, 82, 82, 64)   4096        ['activation_6[0][0]']           
                                                                                                  
 conv2d_9 (Conv2D)              (None, 82, 82, 64)   36864       ['activation_6[0][0]']           
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 82, 82, 64)  192         ['conv2d_8[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 batch_normalization_9 (BatchNo  (None, 82, 82, 64)  192         ['conv2d_9[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_7 (Activation)      (None, 82, 82, 64)   0           ['batch_normalization_8[0][0]']  
                                                                                                  
 activation_8 (Activation)      (None, 82, 82, 64)   0           ['batch_normalization_9[0][0]']  
                                                                                                  
 concatenate_2 (Concatenate)    (None, 82, 82, 128)  0           ['activation_7[0][0]',           
                                                                  'activation_8[0][0]']           
                                                                                                  
 max_pooling2d_2 (MaxPooling2D)  (None, 41, 41, 128)  0          ['concatenate_2[0][0]']          
                                                                                                  
 conv2d_10 (Conv2D)             (None, 41, 41, 48)   6144        ['max_pooling2d_2[0][0]']        
                                                                                                  
 batch_normalization_10 (BatchN  (None, 41, 41, 48)  144         ['conv2d_10[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_9 (Activation)      (None, 41, 41, 48)   0           ['batch_normalization_10[0][0]'] 
                                                                                                  
 conv2d_11 (Conv2D)             (None, 41, 41, 48)   2304        ['activation_9[0][0]']           
                                                                                                  
 conv2d_12 (Conv2D)             (None, 41, 41, 48)   20736       ['activation_9[0][0]']           
                                                                                                  
 batch_normalization_11 (BatchN  (None, 41, 41, 48)  144         ['conv2d_11[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_12 (BatchN  (None, 41, 41, 48)  144         ['conv2d_12[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_10 (Activation)     (None, 41, 41, 48)   0           ['batch_normalization_11[0][0]'] 
                                                                                                  
 activation_11 (Activation)     (None, 41, 41, 48)   0           ['batch_normalization_12[0][0]'] 
                                                                                                  
 concatenate_3 (Concatenate)    (None, 41, 41, 96)   0           ['activation_10[0][0]',          
                                                                  'activation_11[0][0]']          
                                                                                                  
 max_pooling2d_3 (MaxPooling2D)  (None, 20, 20, 96)  0           ['concatenate_3[0][0]']          
                                                                                                  
 conv2d_13 (Conv2D)             (None, 20, 20, 24)   2304        ['max_pooling2d_3[0][0]']        
                                                                                                  
 batch_normalization_13 (BatchN  (None, 20, 20, 24)  72          ['conv2d_13[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_12 (Activation)     (None, 20, 20, 24)   0           ['batch_normalization_13[0][0]'] 
                                                                                                  
 conv2d_14 (Conv2D)             (None, 20, 20, 24)   576         ['activation_12[0][0]']          
                                                                                                  
 conv2d_15 (Conv2D)             (None, 20, 20, 24)   5184        ['activation_12[0][0]']          
                                                                                                  
 batch_normalization_14 (BatchN  (None, 20, 20, 24)  72          ['conv2d_14[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 batch_normalization_15 (BatchN  (None, 20, 20, 24)  72          ['conv2d_15[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_13 (Activation)     (None, 20, 20, 24)   0           ['batch_normalization_14[0][0]'] 
                                                                                                  
 activation_14 (Activation)     (None, 20, 20, 24)   0           ['batch_normalization_15[0][0]'] 
                                                                                                  
 concatenate_4 (Concatenate)    (None, 20, 20, 48)   0           ['activation_13[0][0]',          
                                                                  'activation_14[0][0]']          
                                                                                                  
 global_average_pooling2d (Glob  (None, 48)          0           ['concatenate_4[0][0]']          
 alAveragePooling2D)                                                                              
                                                                                                  
 dropout (Dropout)              (None, 48)           0           ['global_average_pooling2d[0][0]'
                                                                 ]                                
                                                                                                  
 dense (Dense)                  (None, 5)            245         ['dropout[0][0]']                
                                                                                                  
==================================================================================================
Total params: 119,365
Trainable params: 118,053
Non-trainable params: 1,312
__________________________________________________________________________________________________
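
The parameter counts in the last rows of the summary can be checked by hand. The sketch below is a rough verification, not the model code. It assumes that the convolutions carry no bias term, that conv2d_14 uses a 1x1 kernel and conv2d_15 a 3x3 kernel, and that the batch normalization layers have no scale (gamma) parameter. These assumptions are inferred from the numbers in the table, and the helper names are hypothetical.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters, use_bias=True):
    """Conv2D weights: one kernel_h x kernel_w kernel per (input, output) channel pair."""
    weights = kernel_h * kernel_w * in_channels * filters
    return weights + (filters if use_bias else 0)


def batchnorm_params(channels, scale=True, center=True):
    """Returns (trainable, non_trainable) counts for a batch normalization layer."""
    trainable = channels * (int(scale) + int(center))  # gamma and/or beta
    non_trainable = 2 * channels                        # moving mean and moving variance
    return trainable, non_trainable


def dense_params(inputs, units):
    """Dense layer: weight matrix plus one bias per unit."""
    return inputs * units + units


# Assumed configurations, chosen to match the summary rows.
print(conv2d_params(1, 1, 24, 24, use_bias=False))   # conv2d_14 -> 576
print(conv2d_params(3, 3, 24, 24, use_bias=False))   # conv2d_15 -> 5184
print(sum(batchnorm_params(24, scale=False)))        # batch_normalization_14 -> 72
print(dense_params(48, 5))                           # dense -> 245
```

Under these assumptions, each batch normalization layer contributes 24 trainable and 48 non-trainable parameters, which is where the non-trainable total in the summary comes from.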

Step functions

In [11]:
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        probabilities = model(images, training=True)
        loss = loss_fn(labels, probabilities)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_accuracy.update_state(labels, probabilities)
    return loss

@tf.function
def valid_step(images, labels):
    probabilities = model(images, training=False)
    loss = loss_fn(labels, probabilities)
    valid_accuracy.update_state(labels, probabilities)
    return loss

Custom training loop

In [12]:
# distribute the dataset according to the strategy
train_dist_ds = strategy.experimental_distribute_dataset(training_dataset)
valid_dist_ds = strategy.experimental_distribute_dataset(validation_dataset)

print("Steps per epoch: ", steps_per_epoch)

epoch = 0
train_losses=[]
start_time = epoch_start_time = time.time()

for step, (images, labels) in enumerate(train_dist_ds):

    # batch losses from all replicas
    loss = strategy.run(train_step, args=(images, labels))
    # reduced to a single number both across replicas and across the batch size
    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss, axis=None)
    # or use strategy.experimental_local_results(loss) to access the raw set of losses returned from all replicas

    # validation run at the end of each epoch
    if ((step+1) // steps_per_epoch) > epoch:
        valid_loss = []
        for valid_images, valid_labels in valid_dist_ds:
            batch_loss = strategy.run(valid_step, args=(valid_images, valid_labels)) # just one batch
            batch_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, batch_loss, axis=None)
            valid_loss.append(batch_loss.numpy())
        valid_loss = np.mean(valid_loss)

        epoch = (step+1) // steps_per_epoch
        epoch_time = time.time() - epoch_start_time
        print('\nEPOCH: ', epoch)
        print('time: {:0.1f}s'.format(epoch_time),
              ', loss: ', loss.numpy(),
              ', accuracy_: ', train_accuracy.result().numpy(),
              ', val_loss: ', valid_loss,
              ', val_acc_: ', valid_accuracy.result().numpy(),
              ', lr: ', lrfn(epoch)
             )
        
        epoch_start_time = time.time()
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        if epoch >= EPOCHS:
            break
            
    train_losses.append(loss)
    print('=', end='')
    
training_time = time.time() - start_time
print("TOTAL TRAINING TIME: {:0.1f}s".format(training_time))
Steps per epoch:  23
======================
EPOCH:  1
time: 41.1s , loss:  1.3591349 , accuracy_:  0.33491847 , val_loss:  22.189926 , val_acc_:  0.25390625 , lr:  0.08
=======================
EPOCH:  2
time: 7.9s , loss:  1.1927705 , accuracy_:  0.43953803 , val_loss:  1.7011714 , val_acc_:  0.4453125 , lr:  0.08
=======================
EPOCH:  3
time: 7.8s , loss:  1.3540113 , accuracy_:  0.47758153 , val_loss:  1.4239981 , val_acc_:  0.50390625 , lr:  0.08
=======================
EPOCH:  4
time: 7.7s , loss:  1.1380371 , accuracy_:  0.49524456 , val_loss:  1.3725808 , val_acc_:  0.52734375 , lr:  0.08
=======================
EPOCH:  5
time: 7.7s , loss:  1.1323333 , accuracy_:  0.53566575 , val_loss:  1.3797457 , val_acc_:  0.5390625 , lr:  0.08
=======================
EPOCH:  6
time: 7.9s , loss:  1.1515396 , accuracy_:  0.5461956 , val_loss:  1.0639442 , val_acc_:  0.5859375 , lr:  0.08
=======================
EPOCH:  7
time: 7.6s , loss:  1.0200263 , accuracy_:  0.55672556 , val_loss:  1.0803308 , val_acc_:  0.5800781 , lr:  0.08
=======================
EPOCH:  8
time: 8.1s , loss:  1.1748321 , accuracy_:  0.56589675 , val_loss:  0.9476638 , val_acc_:  0.609375 , lr:  0.08
=======================
EPOCH:  9
time: 7.9s , loss:  1.0530715 , accuracy_:  0.5855978 , val_loss:  0.9900672 , val_acc_:  0.6269531 , lr:  0.08
=======================
EPOCH:  10
time: 7.8s , loss:  1.069793 , accuracy_:  0.5808424 , val_loss:  1.0800121 , val_acc_:  0.5800781 , lr:  0.08
=======================
EPOCH:  11
time: 7.6s , loss:  1.0723163 , accuracy_:  0.5944294 , val_loss:  1.369957 , val_acc_:  0.5566406 , lr:  0.08
=======================
EPOCH:  12
time: 7.6s , loss:  1.3123384 , accuracy_:  0.5784647 , val_loss:  0.9436419 , val_acc_:  0.65625 , lr:  0.08
=======================
EPOCH:  13
time: 7.7s , loss:  0.989202 , accuracy_:  0.58967394 , val_loss:  1.001816 , val_acc_:  0.57421875 , lr:  0.08
=======================
EPOCH:  14
time: 7.8s , loss:  1.0398531 , accuracy_:  0.5998641 , val_loss:  1.0300725 , val_acc_:  0.6171875 , lr:  0.07604999999999999
=======================
EPOCH:  15
time: 7.8s , loss:  0.93899727 , accuracy_:  0.6222826 , val_loss:  0.96907616 , val_acc_:  0.640625 , lr:  0.0722975
=======================
EPOCH:  16
time: 7.6s , loss:  1.0537727 , accuracy_:  0.61514944 , val_loss:  1.0858477 , val_acc_:  0.5644531 , lr:  0.06873262499999999
=======================
EPOCH:  17
time: 7.6s , loss:  0.94413257 , accuracy_:  0.625 , val_loss:  0.9053484 , val_acc_:  0.63671875 , lr:  0.06534599375
=======================
EPOCH:  18
time: 7.9s , loss:  0.90452623 , accuracy_:  0.6382473 , val_loss:  0.96852255 , val_acc_:  0.5859375 , lr:  0.06212869406249998
=======================
EPOCH:  19
time: 7.8s , loss:  0.8474844 , accuracy_:  0.6154891 , val_loss:  1.0920925 , val_acc_:  0.6484375 , lr:  0.05907225935937498
=======================
EPOCH:  20
time: 8.1s , loss:  1.0140264 , accuracy_:  0.6576087 , val_loss:  0.96260154 , val_acc_:  0.6621094 , lr:  0.05616864639140623
=======================
EPOCH:  21
time: 7.6s , loss:  0.93450826 , accuracy_:  0.64368206 , val_loss:  0.84104973 , val_acc_:  0.68359375 , lr:  0.05341021407183592
=======================
EPOCH:  22
time: 7.9s , loss:  0.8612748 , accuracy_:  0.65930706 , val_loss:  1.0461957 , val_acc_:  0.58203125 , lr:  0.05078970336824412
=======================
EPOCH:  23
time: 7.6s , loss:  0.9631493 , accuracy_:  0.6389266 , val_loss:  0.88524663 , val_acc_:  0.6933594 , lr:  0.048300218199831914
=======================
EPOCH:  24
time: 7.6s , loss:  0.8471291 , accuracy_:  0.68172556 , val_loss:  0.98489004 , val_acc_:  0.6230469 , lr:  0.04593520728984032
=======================
EPOCH:  25
time: 7.6s , loss:  0.8293517 , accuracy_:  0.65964675 , val_loss:  1.0098193 , val_acc_:  0.6191406 , lr:  0.0436884469253483
=======================
EPOCH:  26
time: 7.8s , loss:  0.74953043 , accuracy_:  0.6637228 , val_loss:  0.83294183 , val_acc_:  0.671875 , lr:  0.04155402457908088
=======================
EPOCH:  27
time: 7.6s , loss:  0.9532999 , accuracy_:  0.6858016 , val_loss:  0.86955905 , val_acc_:  0.6933594 , lr:  0.03952632335012683
=======================
EPOCH:  28
time: 7.8s , loss:  0.83798003 , accuracy_:  0.6779891 , val_loss:  0.82175803 , val_acc_:  0.6816406 , lr:  0.03760000718262049
=======================
EPOCH:  29
time: 7.8s , loss:  0.9849446 , accuracy_:  0.6824049 , val_loss:  0.896647 , val_acc_:  0.640625 , lr:  0.035770006823489464
=======================
EPOCH:  30
time: 7.8s , loss:  0.6697856 , accuracy_:  0.70516306 , val_loss:  0.77810496 , val_acc_:  0.7207031 , lr:  0.03403150648231499
=======================
EPOCH:  31
time: 7.8s , loss:  0.6558618 , accuracy_:  0.6925951 , val_loss:  0.8638859 , val_acc_:  0.703125 , lr:  0.03237993115819924
=======================
EPOCH:  32
time: 7.8s , loss:  0.83388436 , accuracy_:  0.7095788 , val_loss:  0.93534267 , val_acc_:  0.67578125 , lr:  0.030810934600289275
=======================
EPOCH:  33
time: 7.8s , loss:  0.80463845 , accuracy_:  0.6881794 , val_loss:  0.7605772 , val_acc_:  0.7167969 , lr:  0.02932038787027481
=======================
EPOCH:  34
time: 7.8s , loss:  0.85429716 , accuracy_:  0.7085598 , val_loss:  0.7597219 , val_acc_:  0.7324219 , lr:  0.02790436847676107
=======================
EPOCH:  35
time: 7.6s , loss:  0.9584191 , accuracy_:  0.7055027 , val_loss:  0.7615415 , val_acc_:  0.7011719 , lr:  0.026559150052923013
=======================
EPOCH:  36
time: 7.8s , loss:  0.75820315 , accuracy_:  0.7102581 , val_loss:  0.79112947 , val_acc_:  0.7050781 , lr:  0.025281192550276863
=======================
EPOCH:  37
time: 7.9s , loss:  0.6969924 , accuracy_:  0.7272419 , val_loss:  0.89511955 , val_acc_:  0.68359375 , lr:  0.02406713292276302
=======================
EPOCH:  38
time: 7.6s , loss:  0.7745343 , accuracy_:  0.70720106 , val_loss:  0.7736124 , val_acc_:  0.7011719 , lr:  0.02291377627662487
=======================
EPOCH:  39
time: 7.6s , loss:  0.8938252 , accuracy_:  0.7184103 , val_loss:  0.8276354 , val_acc_:  0.7128906 , lr:  0.021818087462793623
=======================
EPOCH:  40
time: 7.7s , loss:  0.62723714 , accuracy_:  0.7214674 , val_loss:  0.7376336 , val_acc_:  0.73828125 , lr:  0.02077718308965394
=======================
EPOCH:  41
time: 7.6s , loss:  0.72016454 , accuracy_:  0.717731 , val_loss:  0.7465435 , val_acc_:  0.7402344 , lr:  0.019788323935171243
=======================
EPOCH:  42
time: 7.8s , loss:  0.75606847 , accuracy_:  0.7194294 , val_loss:  0.7560054 , val_acc_:  0.734375 , lr:  0.01884890773841268
=======================
EPOCH:  43
time: 7.6s , loss:  0.6489076 , accuracy_:  0.72860056 , val_loss:  0.75105846 , val_acc_:  0.72265625 , lr:  0.017956462351492047
=======================
EPOCH:  44
time: 7.8s , loss:  0.67065036 , accuracy_:  0.7398098 , val_loss:  0.750391 , val_acc_:  0.7285156 , lr:  0.017108639233917443
=======================
EPOCH:  45
time: 7.8s , loss:  0.7675567 , accuracy_:  0.71501356 , val_loss:  0.73273534 , val_acc_:  0.7636719 , lr:  0.01630320727222157
=======================
EPOCH:  46
time: 7.7s , loss:  0.7654125 , accuracy_:  0.73029894 , val_loss:  0.81605816 , val_acc_:  0.7128906 , lr:  0.015538046908610489
=======================
EPOCH:  47
time: 7.6s , loss:  0.89121294 , accuracy_:  0.73063856 , val_loss:  0.76160526 , val_acc_:  0.7421875 , lr:  0.014811144563179963
=======================
EPOCH:  48
time: 7.8s , loss:  0.56929743 , accuracy_:  0.7411685 , val_loss:  0.68539953 , val_acc_:  0.7890625 , lr:  0.014120587335020966
=======================
EPOCH:  49
time: 7.6s , loss:  0.68999815 , accuracy_:  0.74524456 , val_loss:  0.7002586 , val_acc_:  0.74609375 , lr:  0.013464557968269918
=======================
EPOCH:  50
time: 7.6s , loss:  0.72345835 , accuracy_:  0.73777175 , val_loss:  0.75025487 , val_acc_:  0.7167969 , lr:  0.01284133006985642
=======================
EPOCH:  51
time: 7.8s , loss:  0.7644169 , accuracy_:  0.7493206 , val_loss:  0.7927157 , val_acc_:  0.7480469 , lr:  0.012249263566363598
=======================
EPOCH:  52
time: 7.8s , loss:  0.72519517 , accuracy_:  0.7571331 , val_loss:  0.84032404 , val_acc_:  0.68359375 , lr:  0.01168680038804542
=======================
EPOCH:  53
time: 7.6s , loss:  0.7161773 , accuracy_:  0.7425272 , val_loss:  0.73635113 , val_acc_:  0.75 , lr:  0.011152460368643147
=======================
EPOCH:  54
time: 7.8s , loss:  0.6453468 , accuracy_:  0.74422556 , val_loss:  0.7183797 , val_acc_:  0.7480469 , lr:  0.01064483735021099
=======================
EPOCH:  55
time: 7.6s , loss:  0.62045324 , accuracy_:  0.76256794 , val_loss:  0.7136475 , val_acc_:  0.76953125 , lr:  0.010162595482700439
=======================
EPOCH:  56
time: 7.8s , loss:  0.6213637 , accuracy_:  0.7717391 , val_loss:  0.7002554 , val_acc_:  0.7636719 , lr:  0.009704465708565417
=======================
EPOCH:  57
time: 7.6s , loss:  0.75231487 , accuracy_:  0.76222825 , val_loss:  0.6813812 , val_acc_:  0.78515625 , lr:  0.009269242423137147
=======================
EPOCH:  58
time: 7.8s , loss:  0.71134186 , accuracy_:  0.7472826 , val_loss:  0.73741865 , val_acc_:  0.7480469 , lr:  0.008855780301980289
=======================
EPOCH:  59
time: 7.8s , loss:  0.6392013 , accuracy_:  0.7578125 , val_loss:  0.71753514 , val_acc_:  0.75390625 , lr:  0.008462991286881274
=======================
EPOCH:  60
time: 7.6s , loss:  0.69161326 , accuracy_:  0.7571331 , val_loss:  0.67235625 , val_acc_:  0.76171875 , lr:  0.008089841722537211
TOTAL TRAINING TIME: 498.8s
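
The bookkeeping of the loop above can be replayed without TensorFlow. The sketch below only simulates the counters: it shows where the epoch boundaries fall when a single flat loop runs over a repeating dataset, and why the first progress bar in the log has 22 marks while the others have 23. It also includes a learning rate function whose constants were fitted to the printed lr column. That function is an assumption (hold at the maximum, then exponential decay toward a floor), and its names are hypothetical; the lrfn defined earlier in the notebook may be written differently.

```python
STEPS_PER_EPOCH = 23   # "Steps per epoch" from the log
EPOCHS = 60            # last epoch printed in the log

# Hypothetical schedule constants, fitted to the printed lr column.
LR_MAX = 0.08
LR_MIN = 0.001
LR_HOLD_EPOCHS = 13    # the log shows lr = 0.08 through epoch 13
LR_DECAY = 0.95


def lr_sketch(epoch):
    """Hold LR_MAX, then decay exponentially toward LR_MIN (assumed shape)."""
    if epoch <= LR_HOLD_EPOCHS:
        return LR_MAX
    return (LR_MAX - LR_MIN) * LR_DECAY ** (epoch - LR_HOLD_EPOCHS) + LR_MIN


def simulate_loop(steps_per_epoch, epochs):
    """Replays the loop counters over an endless stream of batches."""
    epoch = 0
    recorded_steps = 0   # how many times train_losses.append(...) would run
    boundaries = []      # (step index, epoch number) at each validation run
    step = 0
    while True:
        # a distributed training step would run here
        if (step + 1) // steps_per_epoch > epoch:
            epoch = (step + 1) // steps_per_epoch
            boundaries.append((step, epoch))
            if epoch >= epochs:
                break    # exits before the final step is recorded
        recorded_steps += 1
        step += 1
    return boundaries, recorded_steps


boundaries, recorded_steps = simulate_loop(STEPS_PER_EPOCH, EPOCHS)
print(boundaries[0], boundaries[-1])   # (22, 1) (1379, 60)
print(recorded_steps)                  # 1379
print(lr_sketch(14), lr_sketch(60))    # close to the lr values logged for epochs 14 and 60
```

Because the epoch report is printed before the progress mark of the boundary step, that mark opens the next row, and because the loop breaks before the last append, the loss list is one entry shorter than the number of training steps.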
In [13]:
print("Detailed training loss:")
plt.plot(train_losses)
plt.show()
Detailed training loss:

Predictions (not distributed)

In [14]:
# randomize the input so that you can execute multiple times to change results
permutation = np.random.permutation(8*20)
some_flowers, some_labels = (some_flowers[permutation], some_labels[permutation])

predictions = model.predict(some_flowers, batch_size=16)
  
print(np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist())
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/engine/training.py:2975: StrategyBase.unwrap (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
use `experimental_local_results` instead.
['daisy', 'tulips', 'dandelion', 'tulips', 'dandelion', 'daisy', 'tulips', 'tulips', 'roses', 'roses', 'daisy', 'tulips', 'roses', 'sunflowers', 'daisy', 'tulips', 'daisy', 'tulips', 'dandelion', 'roses', 'roses', 'sunflowers', 'sunflowers', 'daisy', 'tulips', 'daisy', 'tulips', 'tulips', 'daisy', 'dandelion', 'sunflowers', 'dandelion', 'daisy', 'roses', 'dandelion', 'tulips', 'tulips', 'daisy', 'tulips', 'roses', 'tulips', 'dandelion', 'sunflowers', 'daisy', 'daisy', 'daisy', 'daisy', 'roses', 'tulips', 'roses', 'daisy', 'daisy', 'sunflowers', 'tulips', 'tulips', 'dandelion', 'dandelion', 'daisy', 'tulips', 'dandelion', 'daisy', 'sunflowers', 'sunflowers', 'roses', 'dandelion', 'dandelion', 'roses', 'dandelion', 'tulips', 'tulips', 'tulips', 'roses', 'dandelion', 'roses', 'tulips', 'dandelion', 'sunflowers', 'daisy', 'dandelion', 'dandelion', 'sunflowers', 'dandelion', 'sunflowers', 'dandelion', 'dandelion', 'tulips', 'dandelion', 'tulips', 'dandelion', 'sunflowers', 'dandelion', 'dandelion', 'dandelion', 'daisy', 'roses', 'tulips', 'sunflowers', 'roses', 'dandelion', 'dandelion', 'daisy', 'dandelion', 'sunflowers', 'roses', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'tulips', 'roses', 'dandelion', 'sunflowers', 'dandelion', 'tulips', 'dandelion', 'dandelion', 'roses', 'roses', 'dandelion', 'dandelion', 'dandelion', 'roses', 'dandelion', 'tulips', 'daisy', 'dandelion', 'tulips', 'dandelion', 'tulips', 'tulips', 'dandelion', 'daisy', 'sunflowers', 'roses', 'sunflowers', 'sunflowers', 'daisy', 'sunflowers', 'daisy', 'daisy', 'dandelion', 'sunflowers', 'dandelion', 'roses', 'roses', 'tulips', 'roses', 'daisy', 'sunflowers', 'sunflowers', 'dandelion', 'dandelion', 'dandelion', 'dandelion', 'roses', 'dandelion', 'roses', 'dandelion', 'roses', 'roses']
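
The list above comes from taking, for each image, the index of the largest predicted probability and looking up the class name at that index. A minimal pure-Python sketch of that decoding step follows. The class order and the probability values are made up for illustration; the notebook defines CLASSES earlier and may order it differently.

```python
# Assumed class order, for illustration only.
CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']


def decode_predictions(probabilities, class_names):
    """Maps each row of class probabilities to the name of its largest entry."""
    decoded = []
    for row in probabilities:
        # argmax over the last axis: index of the highest probability in this row
        best_index = max(range(len(row)), key=row.__getitem__)
        decoded.append(class_names[best_index])
    return decoded


# Made-up probabilities for three images, five classes each.
fake_predictions = [
    [0.70, 0.10, 0.10, 0.05, 0.05],
    [0.05, 0.10, 0.15, 0.10, 0.60],
    [0.10, 0.55, 0.15, 0.10, 0.10],
]
print(decode_predictions(fake_predictions, CLASSES))  # ['daisy', 'tulips', 'dandelion']
```

The NumPy expression in the cell does the same thing in one vectorized step: np.argmax(..., axis=-1) yields one index per row, and indexing the array of class names with those indices returns the labels.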
In [15]:
display_9_images_with_predictions(some_flowers, predictions, some_labels)