In [1]:
import tensorflow as tf
print(tf.__version__)
2.2.0
In [2]:
import numpy as np

np.random.seed(2022)
In [3]:
# Windows で music21 を使うための余分なコード
import os
from music21 import environment

if os.name == 'nt':
    us = environment.UserSettings()
    us['musescoreDirectPNGPath']='C:/Program Files/MuseScore 3/bin/MuseScore3.exe'
    us['musicxmlPath']='C:/Program Files/MuseScore 3/bin/MuseScore3.exe'
In [4]:
data_dir = 'D:\\data\\gdl_book14\\MIDI\\jsbach'
In [5]:
if os.name == 'nt':
    LS_CMD = 'dir /s'
else:
    LS_CMD = 'ls -lR'
    
!{LS_CMD} {data_dir}
 ドライブ D のボリューム ラベルがありません。
 ボリューム シリアル番号は 606C-349E です

 D:\data\gdl_book14\MIDI\jsbach のディレクトリ

2021/12/22  08:37    <DIR>          .
2021/12/22  08:37    <DIR>          ..
2021/12/11  10:53    <DIR>          bwv1007
2021/12/22  08:37             7,206 bwv1007.zip
2021/12/11  10:53    <DIR>          bwv1008
2021/12/22  08:37             8,021 bwv1008.zip
2021/12/11  10:53    <DIR>          bwv1009
2021/12/22  08:37             8,783 bwv1009.zip
2021/12/11  10:53    <DIR>          bwv1010
2021/12/22  08:37             8,622 bwv1010.zip
2021/12/11  10:53    <DIR>          bwv1011
2021/12/22  08:37             8,973 bwv1011.zip
2021/12/11  10:53    <DIR>          bwv1012
2021/12/22  08:37            10,242 bwv1012.zip
               6 個のファイル              51,847 バイト

 D:\data\gdl_book14\MIDI\jsbach\bwv1007 のディレクトリ

2021/12/11  10:53    <DIR>          .
2021/12/11  10:53    <DIR>          ..
1997/02/22  21:59             4,922 cs1-1pre.mid
1997/02/22  21:59             7,088 cs1-2all.mid
1997/02/22  22:00             6,159 cs1-3cou.mid
1997/02/22  22:25             3,209 cs1-4sar.mid
1997/02/22  23:13             5,704 cs1-5men.mid
1997/02/22  22:01             3,966 cs1-6gig.mid
               6 個のファイル              31,048 バイト

 D:\data\gdl_book14\MIDI\jsbach\bwv1008 のディレクトリ

2021/12/11  10:53    <DIR>          .
2021/12/11  10:53    <DIR>          ..
1997/02/22  22:03             4,957 cs2-1pre.mid
1997/02/22  22:03             5,765 cs2-2all.mid
1997/02/22  22:03             5,508 cs2-3cou.mid
1997/02/22  22:24             4,049 cs2-4sar.mid
1997/02/22  23:14             6,014 cs2-5men.mid
1997/02/22  22:27             6,183 cs2-6gig.mid
               6 個のファイル              32,476 バイト

 D:\data\gdl_book14\MIDI\jsbach\bwv1009 のディレクトリ

2021/12/11  10:53    <DIR>          .
2021/12/11  10:53    <DIR>          ..
1997/02/22  22:27             7,331 cs3-1pre.mid
1997/02/22  22:28             7,097 cs3-2all.mid
1997/02/22  22:27             8,019 cs3-3cou.mid
1997/02/22  22:28             4,021 cs3-4sar.mid
1997/02/22  22:28             7,484 cs3-5bou.mid
1997/02/22  22:34             7,850 cs3-6gig.mid
               6 個のファイル              41,802 バイト

 D:\data\gdl_book14\MIDI\jsbach\bwv1010 のディレクトリ

2021/12/11  10:53    <DIR>          .
2021/12/11  10:53    <DIR>          ..
1997/02/22  22:29             6,701 cs4-1pre.mid
1997/02/22  22:10             7,718 cs4-2all.mid
1997/02/22  22:29             7,528 cs4-3cou.mid
1997/02/22  22:29             3,880 cs4-4sar.mid
1997/02/22  22:11             9,968 cs4-5bou.mid
1997/02/22  22:12             7,770 cs4-6gig.mid
               6 個のファイル              43,565 バイト

 D:\data\gdl_book14\MIDI\jsbach\bwv1011 のディレクトリ

2021/12/11  10:53    <DIR>          .
2021/12/11  10:53    <DIR>          ..
1997/02/22  22:30             9,777 cs5-1pre.mid
1997/02/22  22:37             6,091 cs5-2all.mid
1997/02/22  22:38             4,673 cs5-3cou.mid
1997/02/22  22:14             2,538 cs5-4sar.mid
1997/02/22  22:39            10,660 cs5-5gav.mid
1997/02/22  22:31             3,900 cs5-6gig.mid
               6 個のファイル              37,639 バイト

 D:\data\gdl_book14\MIDI\jsbach\bwv1012 のディレクトリ

2021/12/11  10:53    <DIR>          .
2021/12/11  10:53    <DIR>          ..
1997/02/22  22:15            10,451 cs6-1pre.mid
1997/02/22  22:16             8,236 cs6-2all.mid
1997/02/22  22:31             9,135 cs6-3cou.mid
1997/02/22  22:49             5,424 cs6-4sar.mid
1997/02/22  22:18             9,907 cs6-5gav.mid
1997/02/22  22:32             9,656 cs6-6gig.mid
               6 個のファイル              52,809 バイト

     ファイルの総数:
              42 個のファイル             291,186 バイト
              20 個のディレクトリ  1,749,686,632,448 バイトの空き領域
In [6]:
import os
import glob

midi_paths = glob.glob(os.path.join(data_dir, '*/*.mid'))
In [7]:
print(midi_paths[:2])
['D:\\data\\gdl_book14\\MIDI\\jsbach\\bwv1007\\cs1-1pre.mid', 'D:\\data\\gdl_book14\\MIDI\\jsbach\\bwv1007\\cs1-2all.mid']
In [8]:
data_filepath = 'run/music_params.pkl'
In [9]:
if os.name == 'nt':
    CAT_CMD = 'type'
else:
    CAT_CMD = 'cat'
    
! {CAT_CMD} {os.path.join('nw', 'LSTMMusic.py')}
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import os
import glob
import pickle as pkl
import datetime

import music21
import itertools

###################################################################################
# DataLoader
###################################################################################
class ScoreDataset(tf.keras.utils.Sequence):
    def __init__(self, save_path='music_param.pkl', midi_paths=None, seq_len=32):
        if not midi_paths is None:
            self._build(midi_paths, seq_len)
            self.save(save_path)
        else:
            self.load(save_path)

        
    def _build(self, midi_paths, seq_len=32):
        self.seq_len = seq_len
        
        self.notes_list, self.durations_list = ScoreDataset.makeMusicData(midi_paths)

        notes_set = sorted(set(itertools.chain.from_iterable(self.notes_list))) # flatten 2D -> 1D, Unique, Sort
        durations_set = sorted(set(itertools.chain.from_iterable(self.durations_list)))

        self.note_to_index, self.index_to_note = ScoreDataset.createLookups(notes_set)
        self.duration_to_index, self.index_to_duration = ScoreDataset.createLookups(durations_set)

        self.c_notes = len(self.note_to_index)
        self.c_durations = len(self.duration_to_index)

        self.notes_index_list = ScoreDataset.convertIndex(self.notes_list, self.note_to_index)
        self.durations_index_list = ScoreDataset.convertIndex(self.durations_list, self.duration_to_index)

        self.n_music = len(self.notes_list)
        self.index = 0

        self._build_tbl()       
        

    @staticmethod
    def extractMidi(midi_path):
        notes, durations = [], []
        score = music21.converter.parse(midi_path).chordify()
        for element in score.flat:
            if isinstance(element, music21.note.Note): # Note
                if element.isRest:  # Rest
                    notes.append(str(element.name))   # no pitch, then name only
                else: # note with pitch
                    notes.append(str(element.nameWithOctave))
                durations.append(element.duration.quaterLength) # 1/4 unit
        
            if isinstance(element, music21.chord.Chord): # chord contains multiple notes
                notes.append('.'.join(n.nameWithOctave for n in element.pitches)) # connect with '.'
                durations.append(element.duration.quarterLength) # 1/4 unit
            
        return notes, durations

    
    @staticmethod
    # [notes1, ..., notesN], [durations1, ..., durationsN]
    def makeMusicData(midi_paths):
        notes_list, durations_list = [], []
        for path in midi_paths:
            notes, durations = ScoreDataset.extractMidi(path)  # notes, durations
            notes_list.append(notes)
            durations_list.append(durations)
    
        return notes_list, durations_list


    @staticmethod
    def createLookups(names):  # Lookup Table
        element_to_index = dict((element, idx) for idx, element in enumerate(names))
        index_to_element = dict((idx, element) for idx, element in enumerate(names))
        return element_to_index, index_to_element


    @staticmethod
    def convertIndex(data, element_to_index):
        return [ [ element_to_index[element] for element in x] for x in data]

    
    def getMidiStream(self, g_notes, g_durations):  # [note_index, ...], [duration_index, ...]
        midi_stream = music21.stream.Stream()
        for note_idx, duration_idx in zip(g_notes, g_durations):
            note = self.index_to_note[note_idx]
            duration = self.index_to_duration[duration_idx]
            if ('.' in note): # chord
                notes_in_chord = note.split('.')
                chord_notes = []
                for n_ in notes_in_chord:
                    new_note = music21.note.Note(n_)
                    new_note.duration = music21.duration.Duration(duration)
                    new_note.storeInstrument = music21.instrument.Violoncello()
                    chord_notes.append(new_note)
                new_chord = music21.chord.Chord(chord_notes)
                midi_stream.append(new_chord)
            elif note == 'rest':
                new_note = music21.note.Rest()
                new_note.duration = music21.duration.Duration(duration)
                new_note.storedInstrument = music21.instrument.Violoncello()
                midi_stream.append(new_note)
            else:
                new_note = music21.note.Note(note)
                new_note.duration = music21.duration.Duration(duration)
                new_note.storedInstrument = music21.instrument.Violoncello()
                midi_stream.append(new_note)

        return midi_stream
        
        
    def _build_tbl(self):
        a = [ len(x)-self.seq_len for x in self.notes_list ]  # [int, int, ...]
        for i in range(1, len(a)):   # cumulative frequency of data
            a[i] = a[i-1] + a[i]
        self.cumulative_freq = a
        #print(f'cumulative_freq: {self.cumulative_freq}')

            
    def searchTbl(self, index):
        index = index % self.__len__()
        low = 0
        high = self.n_music - 1
        for i in range(self.n_music):
            mid = (low + high) // 2
            #print(f'{i}/{self.n_music}: {high} {low} {mid} {index}')
            if self.cumulative_freq[mid] > index:
                if mid == 0 or self.cumulative_freq[mid-1] <= index:
                    return mid
                high = mid - 1
            else:
                low = mid + 1

                
    def __len__(self):
        return self.cumulative_freq[-1]

    
    def __getitem__(self, index):
        if isinstance(index, slice):
            start, stop, step = index.indices(self.__len__())
            if start == None: start = 0
            if stop == None: stop = self.__len__()
            if step == None:
                if start < stop:
                    step = 1
                elif start > stop:
                    step = -1
                else:
                    step = 0                    
            return self.__getitemList__(range(start, stop, step))
        
        elif isinstance(index, int):
            return self.__getitemInt__(index)
        
        else:
            return self.__getitemList__(index)
        
        
    def __getitemList__(self, indices):
        x_notes, x_durations, y_notes, y_durations = [], [], [], []
        for i in indices:
            [x_note, x_duration], [y_note, y_duration] = self.__getitemInt__(i)
            x_notes.append(x_note)
            x_durations.append(x_duration)
            y_notes.append(y_note)
            y_durations.append(y_duration)

        return (x_notes, x_durations), (y_notes, y_durations)
        

        
    def __getitemInt__(self, index):
        index = index % self.__len__()
        #print(f'index = {index} {self.__len__()}')
        tbl_idx = self.searchTbl(index)
        #print(f'tbl_idx = {tbl_idx}')
        tgt = index
        if (tbl_idx > 0):
            tgt -= self.cumulative_freq[tbl_idx - 1]
        #print(f'tgt = {tgt}')
        
        x_note = self.notes_index_list[tbl_idx][tgt: (tgt + self.seq_len)]
        y_note = self.notes_index_list[tbl_idx][tgt + self.seq_len]
        x_duration = self.durations_index_list[tbl_idx][tgt: (tgt + self.seq_len)]
        y_duration = self.durations_index_list[tbl_idx][tgt + self.seq_len]
        
        #ohv_y_note = tf.keras.utils.to_categorical(y_note, self.c_notes)
        #ohv_y_duration = tf.keras.utils.to_categorical(y_duration, self.c_durations)
        
        return (x_note, x_duration), (y_note, y_duration)

    
    def __next__(self):
        self.index += 1
        return self.__getitem__(self.index-1)


    def save(self, filepath):
        dpath, fname = os.path.split(filepath)
        if not os.path.exists(dpath):
            os.makedirs(dpath)

        with open(filepath, 'wb') as f:
            pkl.dump([
                self.seq_len,
                self.notes_list,
                self.durations_list,
                self.note_to_index,
                self.index_to_note,
                self.duration_to_index,
                self.index_to_duration,
                self.c_notes,
                self.c_durations,
                self.notes_index_list,
                self.durations_index_list,
                self.n_music,
                self.index,
                self.cumulative_freq
            ], f)
    

    def load(self, filepath):
        with open(filepath, 'rb') as f:
            params = pkl.load(f)
            
        [
            self.seq_len,
            self.notes_list,
            self.durations_list,
            self.note_to_index,
            self.index_to_note,
            self.duration_to_index,
            self.index_to_duration,
            self.c_notes,
            self.c_durations,
            self.notes_index_list,
            self.durations_index_list,
            self.n_music,
            self.index,
            self.cumulative_freq
        ] = params
    

###################################################################################
# Model
###################################################################################

class LSTMMusic():
    def __init__(self,
                 c_notes,
                 c_durations,
                 seq_len = 32,
                 optimizer='adam',
                 learning_rate = 0.001,
                 embed_size = 100,
                 rnn_units = 256,
                 use_attention = True,
                 epochs = 0,
                 losses = [],
                 n_losses = [],
                 d_losses = [],
                 val_losses = [],
                 val_n_losses = [],
                 val_d_losses = []
                 ):
        self.c_notes = c_notes
        self.c_durations = c_durations
        self.seq_len = seq_len
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.embed_size = embed_size
        self.rnn_units = rnn_units
        self.use_attention = use_attention
        self.epochs = epochs
        self.losses = losses
        self.n_losses = n_losses
        self.d_losses = d_losses
        self.val_losses = val_losses
        self.val_n_losses = val_n_losses
        self.val_d_losses = val_d_losses

        self.model, self.att_model = self._create_network(c_notes, c_durations, embed_size, rnn_units, use_attention)
        self.cce1 = tf.keras.losses.CategoricalCrossentropy(),
        self.cce2 = tf.keras.losses.CategoricalCrossentropy(),

        
    def _create_network(
            self,
            n_notes, 
            n_durations, 
            embed_size, 
            rnn_units, 
            use_attention
    ):
        notes_in = tf.keras.layers.Input(shape=(None,))
        durations_in = tf.keras.layers.Input(shape=(None,))
    
        x1 = tf.keras.layers.Embedding(n_notes, embed_size)(notes_in)
        x2 = tf.keras.layers.Embedding(n_durations, embed_size)(durations_in)
    
        x = tf.keras.layers.Concatenate()([x1, x2])
    
        x = tf.keras.layers.LSTM(rnn_units, return_sequences=True)(x)
        # x = tf.keras.layers.Dropout(0.2)(x)
        
        if use_attention:
            x = tf.keras.layers.LSTM(rnn_units, return_sequences=True)(x)
            # x = tf.keras.layers.Dropout(0.2)(x)
            
            e = tf.keras.layers.Dense(1, activation='tanh')(x)
            e = tf.keras.layers.Reshape([-1])(e)   # batch_size * N 
            alpha = tf.keras.layers.Activation('softmax')(e)
            
            alpha_repeated = tf.keras.layers.Permute([2,1])(tf.keras.layers.RepeatVector(rnn_units)(alpha))
        
            c = tf.keras.layers.Multiply()([x, alpha_repeated])
            c = tf.keras.layers.Lambda(lambda xin: tf.keras.backend.sum(xin, axis=1), output_shape=(rnn_units,))(c)
        
        else:
        
            c = tf.keras.layers.LSTM(rnn_units)(x)
            #c = tf.keras.layers.Dropout(0.2)(c)
    
        notes_out = tf.keras.layers.Dense(n_notes, activation='softmax', name='pitch')(c)
        durations_out = tf.keras.layers.Dense(n_durations, activation='softmax', name='duration')(c)
    
        model = tf.keras.models.Model([notes_in, durations_in], [notes_out, durations_out])
    
        if use_attention:
            att_model = tf.keras.models.Model([notes_in, durations_in], alpha)
        else:
            att_model = None
            
        return model, att_model


    def get_opti(self, learning_rate):
        if self.optimizer == 'adam':
            opti = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1 = 0.5)
        elif self.optimizer == 'rmsprop':
            opti = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)
        else:
            opti = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        return opti


    def _compile(self):
        opti = self.get_opti(self.learning_rate)
        self.model.compile(
            loss=[
                self.cce1,
                self.cce2
            ],
            optimizer = tf.keras.optimizers.RMSprop(learning_rate = self.learning_rate)
        )

        
    def train_with_fit(
            self,
            xs,   # (x_notes,x_durations)
            ys,   # (y_notes, y_durations)
            epochs=1,
            batch_size=32,
            run_folder='run/',
            shuffle=False,
            validation_data = None
    ):
        history = self.model.fit(
            xs,
            ys,
            initial_epoch = self.epochs,
            epochs = epochs,
            batch_size = batch_size,
            shuffle = shuffle,
            validation_data = validation_data
        )
        self.epochs = epochs

        h = history.history
        self.losses += h['loss']
        self.n_losses += h['pitch_loss']
        self.d_losses += h['duration_loss']
        
        if not validation_data is None:
            self.val_losses += h['val_loss']
            self.val_n_losses += h['val_pitch_loss']
            self.val_d_losses += h['val_duration_loss']
            
        self.save(run_folder)
        self.save(run_folder, self.epochs)

        if validation_data is None:
            return self.losses, self.n_losses, self.d_losses
        else:
            return self.losses, self.n_losses, self.d_losses, self.val_losses, self.val_n_losses, self.val_d_losses,


    @tf.function
    def loss_fn(self, y_notes, y_durations, p_notes, p_durations):
        #print(y_notes.shape, p_notes.shape, y_durations.shape, p_durations.shape)
        #y_notes = np.array(y_notes, dtype='float32')
        #p_notes = np.array(p_notes, dtype='float32')
        n_loss = tf.keras.losses.CategoricalCrossentropy()(y_notes, p_notes)
        d_loss = tf.keras.losses.CategoricalCrossentropy()(y_durations, p_durations)
        #n_loss = self.cce1(y_notes, p_notes)
        #d_loss = self.cce2(y_durations, p_durations)
        loss = tf.add(n_loss, d_loss)
        return loss, n_loss, d_loss


    @tf.function
    def train_step(
            self,
            x_notes, 
            x_durations, 
            y_notes, 
            y_durations, 
            optimizer
    ):
        with tf.GradientTape() as tape:
            p_notes, p_durations = self.model([x_notes, x_durations])
            loss, note_loss, duration_loss = self.loss_fn(y_notes, y_durations, p_notes, p_durations)
        variables = self.model.trainable_variables
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        return loss, note_loss, duration_loss

    
    def train(
            self,
            xs,
            ys,
            epochs=1, 
            batch_size=32, 
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
            shuffle=False,
            run_folder='run/',
            print_step_interval = 100,
            save_epoch_interval = 100,
            validation_data = None
    ):
        start_time = datetime.datetime.now()

        x_notes_all, x_durations_all = xs
        y_notes_all, y_durations_all = ys

        steps = len(x_notes_all) // batch_size

        for epoch in range(self.epochs, epochs):
            indices = tf.range(len(x_notes_all), dtype=tf.int32)
            if shuffle:
                indices = tf.random.shuffle(indices)

            step_losses, step_n_losses, step_d_losses = [], [], []
            
            for step in range(steps):
                start = batch_size * step
                end = start + batch_size
                
                idxs = indices[start:end]
                x_notes = x_notes_all[idxs]
                x_durations = x_durations_all[idxs]
                y_notes = np.array(y_notes_all[idxs])
                y_durations = np.array(y_durations_all[idxs])
                
                n_ = np.array(x_notes, dtype='float32')
                d_ = np.array(x_durations, dtype='float32')
                         
                batch_loss, batch_n_loss, batch_d_loss = self.train_step(
                    n_, 
                    d_, 
                    y_notes, 
                    y_durations,
                    optimizer
                )

                step_losses.append(np.mean(batch_loss))
                step_n_losses.append(np.mean(batch_n_loss))
                step_d_losses.append(np.mean(batch_d_loss))
                
                elapsed_time = datetime.datetime.now() - start_time
                if (step+1) % print_step_interval == 0:
                    print(f'{epoch+1}/{epochs} {step+1}/{steps} loss {step_losses[-1]:.3f} pitch_loss {step_n_losses[-1]:.3f} duration_loss {step_d_losses[-1]:.3f} {elapsed_time}')


            epoch_loss = np.mean(step_losses)
            epoch_n_loss = np.mean(step_n_losses)
            epoch_d_loss = np.mean(step_d_losses)

            self.losses.append(epoch_loss)
            self.n_losses.append(epoch_n_loss)
            self.d_losses.append(epoch_d_loss)

            val_str = ''
            if not validation_data is None:
                (val_x_notes, val_x_durations), (val_y_notes, val_y_durations) = validation_data
                p_notes, p_durations = self.model([val_x_notes, val_x_durations])
                val_loss, val_n_loss, val_d_loss = self.loss_fn(
                    val_y_notes,
                    val_y_durations,
                    p_notes,
                    p_durations
                )
                val_loss = np.mean(val_loss)
                val_n_loss = np.mean(val_n_loss)
                val_d_loss = np.mean(val_d_loss)

                self.val_losses.append(val_loss)
                self.val_n_losses.append(val_n_loss)
                self.val_d_losses.append(val_d_loss)

                val_str = f'val_loss {val_loss:.3f} val_pitch_loss {val_n_loss:.3f} val_duration_loss {val_d_loss:.3f}'
                
            self.epochs += 1

            elapsed_time = datetime.datetime.now() - start_time
            print(f'{self.epochs}/{epochs} loss {epoch_loss:.3f} pitch_loss {epoch_n_loss:.3f} duration_loss {epoch_d_loss:.3f} {val_str} {elapsed_time}')
            
            if self.epochs % save_epoch_interval == 0:
                self.save(run_folder)
                self.save(run_folder, self.epochs)

        self.save(run_folder)
        self.save(run_folder, self.epochs)
        
        return self.losses, self.n_losses, self.d_losses, self.val_losses, self.val_n_losses, self.val_d_losses


    @staticmethod
    def sample_with_temperature(preds, temperature):
        if temperature == 0:
            return np.argmax(preds)
        else:
            preds = np.log(preds) / temperature
            exp_preds = np.exp(preds)
            preds = exp_preds / np.sum(exp_preds)
            return np.random.choice(len(preds), p=preds)
        

    def generate(self, s_notes, s_durations, count=64, note_temperature = 0.5, duration_temperature = 0.5):
        x_notes = np.array([s_notes], dtype='float32')
        x_durations = np.array([s_durations], dtype='float32')
            
        g_notes, g_durations = [], []
        for i in range(count):
            p_notes, p_durations = self.model([x_notes, x_durations])
            note = LSTMMusic.sample_with_temperature(p_notes[0], note_temperature)
            duration = LSTMMusic.sample_with_temperature(p_durations[0], duration_temperature)
            g_notes.append(note)
            g_durations.append(duration)

            x_notes = np.roll(x_notes, -1)
            x_durations = np.roll(x_durations, -1)
            x_notes[0,-1] = note
            x_durations[0,-1] = duration

        return g_notes, g_durations  # [note_index, ...], [duration_index, ...]


    def save(self, folder, epoch=None):
        self.save_params(folder, epoch)
        self.save_weights(folder, epoch)


    @staticmethod
    def load(folder, epoch=None):
        params = LSTMMusic.load_params(folder, epoch)
        music = LSTMMusic(*params)
        music.load_weights(folder, epoch)
        return music


    def save_weights(self, run_folder, epoch=None):
        if epoch is None:
            self.save_model_weights(self.model, os.path.join(run_folder, 'weights/weights.h5'))    
            self.save_model_weights(self.att_model, os.path.join(run_folder, 'weights/weights_att.h5'))    
        else:
            self.save_model_weights(self.model, os.path.join(run_folder, f'weights/weights_{epoch}.h5'))
            self.save_model_weights(self.att_model, os.path.join(run_folder, f'weights/weights_att_{epoch}.h5'))


    def load_weights(self, run_folder, epoch=None):
        if epoch is None:
            self.load_model_weights(self.model, os.path.join(run_folder, 'weights/weights.h5'))
            self.load_model_weights(self.att_model, os.path.join(run_folder, 'weights/weights_att.h5'))
        else:
            self.load_model_weights(self.model, os.path.join(run_folder, f'weights/weights_{epoch}.h5'))
            self.load_model_weights(self.att_model, os.path.join(run_folder, f'weights/weights_att_{epoch}.h5'))


    def save_model_weights(self, model, filepath):
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        model.save_weights(filepath)


    def load_model_weights(self, model, filepath):
        model.load_weights(filepath)


    def save_params(self, folder, epoch=None):
        if not os.path.exists(folder):
            os.makedirs(folder)
            
        if epoch is None:
            filepath = os.path.join(folder, 'params.pkl')
        else:
            filepath = os.path.join(folder, f'params_{epoch}.pkl')

        with open(filepath, 'wb') as f:
            pkl.dump([
                self.c_notes,
                self.c_durations,
                self.seq_len,
                self.optimizer,
                self.learning_rate,
                self.embed_size,
                self.rnn_units,
                self.use_attention,
                self.epochs,
                self.losses,
                self.n_losses,
                self.d_losses,
                self.val_losses,
                self.val_n_losses,
                self.val_d_losses
            ], f)


    @staticmethod
    def load_params(folder, epoch=None):
        if epoch is None:
            filepath = os.path.join(folder, 'params.pkl')
        else:
            filepath = os.path.join(folder, f'params_{epoch}.pkl')

        with open(filepath, 'rb') as f:
            params = pkl.load(f)
        return params


    @staticmethod
    def plot_history(vals, labels):
        colors = ['red', 'blue', 'green', 'black', 'orange', 'pink', 'purple', 'olive', 'cyan']
        n = len(vals)
        fig, ax = plt.subplots(1, 1, figsize=(12,6))
        for i in range(n):
            ax.plot(vals[i], c=colors[i], label=labels[i])
        ax.legend(loc='upper right')
        ax.set_xlabel('epochs')

        plt.show()

        
In [10]:
import sys
sys.path.append('./nw')

from LSTMMusic import ScoreDataset

if not os.path.exists(data_filepath):
    data_seq = ScoreDataset(save_path=data_filepath, midi_paths=midi_paths, seq_len=32)
else:
    data_seq = ScoreDataset(save_path=data_filepath)
In [11]:
import sys
sys.path.append('./nw')

from LSTMMusic import LSTMMusic

lstm_music = LSTMMusic(
    data_seq.c_notes,
    data_seq.c_durations
)
In [22]:
lstm_music.model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, None, 100)    46000       input_1[0][0]                    
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 100)    1800        input_2[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, None, 200)    0           embedding[0][0]                  
                                                                 embedding_1[0][0]                
__________________________________________________________________________________________________
lstm (LSTM)                     (None, None, 256)    467968      concatenate[0][0]                
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, None, 256)    525312      lstm[0][0]                       
__________________________________________________________________________________________________
dense (Dense)                   (None, None, 1)      257         lstm_1[0][0]                     
__________________________________________________________________________________________________
reshape (Reshape)               (None, None)         0           dense[0][0]                      
__________________________________________________________________________________________________
activation (Activation)         (None, None)         0           reshape[0][0]                    
__________________________________________________________________________________________________
repeat_vector (RepeatVector)    (None, 256, None)    0           activation[0][0]                 
__________________________________________________________________________________________________
permute (Permute)               (None, None, 256)    0           repeat_vector[0][0]              
__________________________________________________________________________________________________
multiply (Multiply)             (None, None, 256)    0           lstm_1[0][0]                     
                                                                 permute[0][0]                    
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 256)          0           multiply[0][0]                   
__________________________________________________________________________________________________
pitch (Dense)                   (None, 460)          118220      lambda[0][0]                     
__________________________________________________________________________________________________
duration (Dense)                (None, 18)           4626        lambda[0][0]                     
==================================================================================================
Total params: 1,164,183
Trainable params: 1,164,183
Non-trainable params: 0
__________________________________________________________________________________________________
In [23]:
lstm_music.att_model.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, None, 100)    46000       input_1[0][0]                    
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 100)    1800        input_2[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, None, 200)    0           embedding[0][0]                  
                                                                 embedding_1[0][0]                
__________________________________________________________________________________________________
lstm (LSTM)                     (None, None, 256)    467968      concatenate[0][0]                
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, None, 256)    525312      lstm[0][0]                       
__________________________________________________________________________________________________
dense (Dense)                   (None, None, 1)      257         lstm_1[0][0]                     
__________________________________________________________________________________________________
reshape (Reshape)               (None, None)         0           dense[0][0]                      
__________________________________________________________________________________________________
activation (Activation)         (None, None)         0           reshape[0][0]                    
==================================================================================================
Total params: 1,041,337
Trainable params: 1,041,337
Non-trainable params: 0
__________________________________________________________________________________________________
In [12]:
# split data into train and val data.
import numpy as np

val_split = 0.05

N_DATA = len(data_seq)
N_VAL = int(N_DATA * val_split)

arr = np.arange(N_DATA)
np.random.shuffle(arr)

train_indices = sorted(arr[:-N_VAL])
val_indices = sorted(arr[-N_VAL:])

(train_x_notes, train_x_durations), (train_y_notes, train_y_durations) = data_seq[train_indices]
(val_x_notes, val_x_durations), (val_y_notes, val_y_durations) = data_seq[val_indices]

print(len(train_indices), len(val_indices))
25184 1325
In [13]:
print(train_indices[:20])
print(val_indices[:20])
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20]
[16, 30, 38, 48, 57, 61, 65, 121, 150, 155, 206, 211, 213, 309, 311, 339, 352, 389, 397, 429]
In [14]:
train_x_notes = np.array(train_x_notes)
train_x_durations = np.array(train_x_durations)

val_x_notes = np.array(val_x_notes)
val_x_durations = np.array(val_x_durations)
In [15]:
import tensorflow as tf

train_y_notes_ohv = tf.keras.utils.to_categorical(train_y_notes, data_seq.c_notes)
train_y_durations_ohv = tf.keras.utils.to_categorical(train_y_durations, data_seq.c_durations)

val_y_notes_ohv = tf.keras.utils.to_categorical(val_y_notes, data_seq.c_notes)
val_y_durations_ohv = tf.keras.utils.to_categorical(val_y_durations, data_seq.c_durations)
In [16]:
save_path = 'run'
In [17]:
lstm_music._compile()

h = lstm_music.train_with_fit(
        (train_x_notes, train_x_durations), 
        (train_y_notes_ohv, train_y_durations_ohv),
        epochs=1,
        batch_size=32,
        run_folder = save_path,
        validation_data = ((val_x_notes, val_x_durations), (val_y_notes_ohv, val_y_durations_ohv))
)
787/787 [==============================] - 7s 9ms/step - loss: 4.3482 - pitch_loss: 3.5334 - duration_loss: 0.8148 - val_loss: 6.1381 - val_pitch_loss: 4.5460 - val_duration_loss: 1.5922
In [18]:
h = lstm_music.train(
        (train_x_notes, train_x_durations), 
        (train_y_notes_ohv, train_y_durations_ohv),
        epochs=2,
        batch_size=32,
        run_folder = save_path,
        print_step_interval=300,
        validation_data = ((val_x_notes, val_x_durations), (val_y_notes_ohv, val_y_durations_ohv))
)
2/2 300/787 loss 3.889 pitch_loss 2.942 duration_loss 0.947 0:00:04.748380
2/2 600/787 loss 3.419 pitch_loss 2.823 duration_loss 0.595 0:00:07.178785
2/2 loss 4.427 pitch_loss 3.568 duration_loss 0.860 val_loss 5.513 val_pitch_loss 4.372 val_duration_loss 1.141 0:00:08.715082
In [19]:
h = lstm_music.train(
        (train_x_notes, train_x_durations), 
        (train_y_notes_ohv, train_y_durations_ohv),
        epochs=5,
        batch_size=32,
        run_folder = save_path,
        print_step_interval=1000,
        validation_data = ((val_x_notes, val_x_durations), (val_y_notes_ohv, val_y_durations_ohv))
)
3/5 loss 4.274 pitch_loss 3.473 duration_loss 0.802 val_loss 5.688 val_pitch_loss 4.479 val_duration_loss 1.209 0:00:06.292834
4/5 loss 4.197 pitch_loss 3.428 duration_loss 0.769 val_loss 5.219 val_pitch_loss 4.314 val_duration_loss 0.905 0:00:12.556787
5/5 loss 4.142 pitch_loss 3.408 duration_loss 0.734 val_loss 5.211 val_pitch_loss 4.323 val_duration_loss 0.889 0:00:18.786610
In [20]:
! {LS_CMD} {save_path}
 ドライブ G のボリューム ラベルは nitta@gm.tsuda.ac.jp - Google... です
 ボリューム シリアル番号は 1983-1116 です

 G:\マイドライブ\DeepLearning\book14\ch07.self\run のディレクトリ

2022/01/13  09:25    <DIR>          .
2021/12/20  11:19    <DIR>          ..
2021/12/18  12:03           762,452 music_params.pkl
2022/01/13  09:25               651 params.pkl
2022/01/13  09:25    <DIR>          weights
2022/01/13  09:24               121 params_1.pkl
2022/01/13  09:25               327 params_2.pkl
2022/01/13  09:25               651 params_5.pkl
2021/12/18  21:05             1,191 params_10.pkl
2021/12/18  21:52               627 output_0_10.mid
2021/12/18  21:33               611 output_1_10.mid
2021/12/18  21:33               611 output_2_10.mid
2021/12/18  21:46             7,434 params_50.pkl
2021/12/18  21:54               611 output_0_50.mid
2021/12/18  21:49               611 output_1_50.mid
2021/12/18  21:49               611 output_2_50.mid
2021/12/18  21:55            15,534 params_100.pkl
2021/12/18  22:02            23,634 params_150.pkl
2021/12/18  22:10               611 output_0_150.mid
2021/12/18  22:10               611 output_1_150.mid
2021/12/18  22:10               611 output_2_150.mid
2021/12/18  22:10            31,734 params_200.pkl
2021/12/18  22:18            39,834 params_250.pkl
2021/12/18  22:25            47,935 params_300.pkl
2021/12/18  22:33            56,035 params_350.pkl
2021/12/18  22:49               611 output_0_350.mid
2021/12/18  22:37               611 output_1_350.mid
2021/12/18  22:37               611 output_2_350.mid
2021/12/18  22:41            64,135 params_400.pkl
2021/12/18  22:48            72,235 params_450.pkl
2021/12/18  22:51               611 output_0_450.mid
2021/12/18  22:56            80,335 params_500.pkl
2021/12/18  22:57               611 output_0_500.mid
2021/12/18  22:57               611 output_1_500.mid
2021/12/18  22:57               611 output_2_500.mid
2021/12/18  23:04            88,435 params_550.pkl
2021/12/18  23:07               611 output_0_550.mid
2021/12/18  23:08               611 output_1_550.mid
2021/12/18  23:08               611 output_2_550.mid
2021/12/18  23:11            96,535 params_600.pkl
2021/12/18  23:11               611 output_0_600.mid
2021/12/18  23:11               611 output_1_600.mid
2021/12/18  23:11               611 output_2_600.mid
2021/12/18  23:19           104,635 params_650.pkl
2021/12/18  23:26           112,735 params_700.pkl
2021/12/18  23:34           120,835 params_750.pkl
2021/12/18  23:41           128,935 params_800.pkl
2021/12/20  10:33               611 output_0_800.mid
2021/12/20  10:33               611 output_1_800.mid
2021/12/20  10:33               627 output_2_800.mid
              47 個のファイル           1,871,655 バイト

 G:\マイドライブ\DeepLearning\book14\ch07.self\run\weights のディレクトリ

2022/01/13  09:25    <DIR>          .
2022/01/13  09:25    <DIR>          ..
2022/01/13  09:25         4,693,472 weights.h5
2022/01/13  09:25         4,193,232 weights_att.h5
2022/01/13  09:24         4,693,472 weights_1.h5
2022/01/13  09:24         4,193,232 weights_att_1.h5
2022/01/13  09:25         4,693,472 weights_2.h5
2022/01/13  09:25         4,193,232 weights_att_2.h5
2022/01/13  09:25         4,693,472 weights_5.h5
2022/01/13  09:25         4,193,232 weights_att_5.h5
2021/12/18  21:05         4,693,480 weights_10.h5
2021/12/18  21:05         4,193,280 weights_att_10.h5
2021/12/18  21:46         4,693,472 weights_50.h5
2021/12/18  21:46         4,193,232 weights_att_50.h5
2021/12/18  21:55         4,693,472 weights_100.h5
2021/12/18  21:55         4,193,232 weights_att_100.h5
2021/12/18  22:02         4,693,472 weights_150.h5
2021/12/18  22:02         4,193,232 weights_att_150.h5
2021/12/18  22:10         4,693,472 weights_200.h5
2021/12/18  22:10         4,193,232 weights_att_200.h5
2021/12/18  22:18         4,693,472 weights_250.h5
2021/12/18  22:18         4,193,232 weights_att_250.h5
2021/12/18  22:25         4,693,472 weights_300.h5
2021/12/18  22:25         4,193,232 weights_att_300.h5
2021/12/18  22:33         4,693,472 weights_350.h5
2021/12/18  22:33         4,193,232 weights_att_350.h5
2021/12/18  22:41         4,693,472 weights_400.h5
2021/12/18  22:41         4,193,232 weights_att_400.h5
2021/12/18  22:48         4,693,472 weights_450.h5
2021/12/18  22:48         4,193,232 weights_att_450.h5
2021/12/18  22:56         4,693,472 weights_500.h5
2021/12/18  22:56         4,193,232 weights_att_500.h5
2021/12/18  23:04         4,693,472 weights_550.h5
2021/12/18  23:04         4,193,232 weights_att_550.h5
2021/12/18  23:11         4,693,472 weights_600.h5
2021/12/18  23:11         4,193,232 weights_att_600.h5
2021/12/18  23:19         4,693,472 weights_650.h5
2021/12/18  23:19         4,193,232 weights_att_650.h5
2021/12/18  23:26         4,693,472 weights_700.h5
2021/12/18  23:26         4,193,232 weights_att_700.h5
2021/12/18  23:34         4,693,472 weights_750.h5
2021/12/18  23:34         4,193,232 weights_att_750.h5
2021/12/18  23:41         4,693,472 weights_800.h5
2021/12/18  23:41         4,193,232 weights_att_800.h5
              42 個のファイル         186,620,840 バイト

     ファイルの総数:
              89 個のファイル         188,492,495 バイト
               5 個のディレクトリ  106,950,303,744 バイトの空き領域
In [21]:
%matplotlib inline
import matplotlib.pyplot as plt

LSTMMusic.plot_history(
    [
        lstm_music.losses, lstm_music.n_losses, lstm_music.d_losses,
        lstm_music.val_losses, lstm_music.val_n_losses, lstm_music.val_d_losses
    ],
    [
        'loss', 'pitch_loss', 'duration_loss', 'val_loss', 'val_pitch_loss', 'val_duration_loss' 
    ]
)
In [21]:
lstm_music2 = LSTMMusic.load(save_path)
print(lstm_music2.epochs)
5
In [22]:
h=lstm_music2.train(
        (train_x_notes, train_x_durations), 
        (train_y_notes_ohv, train_y_durations_ohv),
        epochs=10,
        batch_size=32,
        run_folder = save_path,
        print_step_interval=1000,
        validation_data = ((val_x_notes, val_x_durations), (val_y_notes_ohv, val_y_durations_ohv))
)
6/10 loss 4.278 pitch_loss 3.496 duration_loss 0.782 val_loss 4.873 val_pitch_loss 3.975 val_duration_loss 0.898 0:00:11.048121
7/10 loss 4.089 pitch_loss 3.376 duration_loss 0.713 val_loss 4.705 val_pitch_loss 3.898 val_duration_loss 0.807 0:00:25.836999
8/10 loss 3.983 pitch_loss 3.316 duration_loss 0.667 val_loss 4.600 val_pitch_loss 3.832 val_duration_loss 0.768 0:00:41.062060
9/10 loss 3.973 pitch_loss 3.308 duration_loss 0.665 val_loss 4.449 val_pitch_loss 3.697 val_duration_loss 0.752 0:00:49.064355
10/10 loss 3.887 pitch_loss 3.253 duration_loss 0.635 val_loss 4.450 val_pitch_loss 3.702 val_duration_loss 0.748 0:00:57.467114
In [23]:
%matplotlib inline
import matplotlib.pyplot as plt

LSTMMusic.plot_history(
    [
        lstm_music2.losses, lstm_music2.n_losses, lstm_music2.d_losses,
        lstm_music2.val_losses, lstm_music2.val_n_losses, lstm_music2.val_d_losses
    ],
    [
        'loss', 'pitch_loss', 'duration_loss', 
        'val_loss', 'val_pitch_loss', 'val_duration_loss' 
    ]
)
In [ ]: