Jul/30/2023 updated by Yoshihisa Nitta
Feb/03/2023 written by Yoshihisa Nitta
Google Colab 対応コード

In [1]:
# by nitta
import os
is_colab = 'google.colab' in str(get_ipython())   # for Google Colab

if is_colab:
    from google.colab import drive
    drive.mount('/content/drive')
    SAVE_PREFIX='/content/drive/MyDrive/DeepLearning/openai_gym_rt'
else:
    SAVE_PREFIX='.'
Mounted at /content/drive

Reinforced Training (強化学習) と Deep Reinforced Training (深層強化学習)

1. Google Colab

In [2]:
# Check if running on Colab
#is_colab = 'google.colab' in str(get_ipython())   # for Google Colab

# packages
if is_colab:
    !apt update -qq
    !apt upgrade -qq
    !apt install -qq xvfb
    !pip -q install pyvirtualdisplay

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import matplotlib
matplotlib.rcParams['animation.embed_limit'] = 2**32
39 packages can be upgraded. Run 'apt list --upgradable' to see them.
The following packages have been kept back:
  libcudnn8 libcudnn8-dev libnccl-dev libnccl2
The following packages will be upgraded:
  base-files bash binutils binutils-common binutils-x86-64-linux-gnu coreutils cuda-compat-12-2
  cuda-keyring cuda-toolkit-12-config-common cuda-toolkit-config-common dpkg dpkg-dev libbinutils
  libc-bin libctf-nobfd0 libctf0 libdpkg-perl libgnutls30 libldap-2.5-0 libpam-modules
  libpam-modules-bin libpam-runtime libpam0g libperl5.34 libprocps8 libudev1 linux-libc-dev login
  openssl passwd perl perl-base perl-modules-5.34 procps tar
35 upgraded, 0 newly installed, 0 to remove and 4 not upgraded.
Need to get 58.1 MB of archives.
After this operation, 63.5 kB of additional disk space will be used.
Extracting templates from packages: 100%
Preconfiguring packages ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../base-files_12ubuntu4.6_amd64.deb ...
Unpacking base-files (12ubuntu4.6) over (12ubuntu4.4) ...
Setting up base-files (12ubuntu4.6) ...
Installing new version of config file /etc/issue ...
Installing new version of config file /etc/issue.net ...
Installing new version of config file /etc/lsb-release ...
Installing new version of config file /etc/update-motd.d/10-help-text ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../bash_5.1-6ubuntu1.1_amd64.deb ...
Unpacking bash (5.1-6ubuntu1.1) over (5.1-6ubuntu1) ...
Setting up bash (5.1-6ubuntu1.1) ...
update-alternatives: using /usr/share/man/man7/bash-builtins.7.gz to provide /usr/share/man/man7/builtins.7.gz (builtins.7.gz) in auto mode
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../coreutils_8.32-4.1ubuntu1.1_amd64.deb ...
Unpacking coreutils (8.32-4.1ubuntu1.1) over (8.32-4.1ubuntu1) ...
Setting up coreutils (8.32-4.1ubuntu1.1) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../tar_1.34+dfsg-1ubuntu0.1.22.04.2_amd64.deb ...
Unpacking tar (1.34+dfsg-1ubuntu0.1.22.04.2) over (1.34+dfsg-1ubuntu0.1.22.04.1) ...
Setting up tar (1.34+dfsg-1ubuntu0.1.22.04.2) ...
update-alternatives: warning: forcing reinstallation of alternative /usr/sbin/rmt-tar because link group rmt is broken
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../dpkg_1.21.1ubuntu2.3_amd64.deb ...
Unpacking dpkg (1.21.1ubuntu2.3) over (1.21.1ubuntu2.2) ...
Setting up dpkg (1.21.1ubuntu2.3) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../login_1%3a4.8.1-2ubuntu2.2_amd64.deb ...
Unpacking login (1:4.8.1-2ubuntu2.2) over (1:4.8.1-2ubuntu2.1) ...
Setting up login (1:4.8.1-2ubuntu2.2) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libperl5.34_5.34.0-3ubuntu1.3_amd64.deb ...
Unpacking libperl5.34:amd64 (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Preparing to unpack .../perl_5.34.0-3ubuntu1.3_amd64.deb ...
Unpacking perl (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Preparing to unpack .../perl-base_5.34.0-3ubuntu1.3_amd64.deb ...
Unpacking perl-base (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Setting up perl-base (5.34.0-3ubuntu1.3) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../perl-modules-5.34_5.34.0-3ubuntu1.3_all.deb ...
Unpacking perl-modules-5.34 (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Preparing to unpack .../libc-bin_2.35-0ubuntu3.6_amd64.deb ...
Unpacking libc-bin (2.35-0ubuntu3.6) over (2.35-0ubuntu3.4) ...
Setting up libc-bin (2.35-0ubuntu3.6) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam0g_1.4.0-11ubuntu2.4_amd64.deb ...
Unpacking libpam0g:amd64 (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam0g:amd64 (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam-modules-bin_1.4.0-11ubuntu2.4_amd64.deb ...
Unpacking libpam-modules-bin (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam-modules-bin (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam-modules_1.4.0-11ubuntu2.4_amd64.deb ...
Unpacking libpam-modules:amd64 (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam-modules:amd64 (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam-runtime_1.4.0-11ubuntu2.4_all.deb ...
Unpacking libpam-runtime (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam-runtime (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libudev1_249.11-0ubuntu3.12_amd64.deb ...
Unpacking libudev1:amd64 (249.11-0ubuntu3.12) over (249.11-0ubuntu3.10) ...
Setting up libudev1:amd64 (249.11-0ubuntu3.12) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../passwd_1%3a4.8.1-2ubuntu2.2_amd64.deb ...
Unpacking passwd (1:4.8.1-2ubuntu2.2) over (1:4.8.1-2ubuntu2.1) ...
Setting up passwd (1:4.8.1-2ubuntu2.2) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libgnutls30_3.7.3-4ubuntu1.4_amd64.deb ...
Unpacking libgnutls30:amd64 (3.7.3-4ubuntu1.4) over (3.7.3-4ubuntu1.2) ...
Setting up libgnutls30:amd64 (3.7.3-4ubuntu1.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../00-libprocps8_2%3a3.3.17-6ubuntu2.1_amd64.deb ...
Unpacking libprocps8:amd64 (2:3.3.17-6ubuntu2.1) over (2:3.3.17-6ubuntu2) ...
Preparing to unpack .../01-procps_2%3a3.3.17-6ubuntu2.1_amd64.deb ...
Unpacking procps (2:3.3.17-6ubuntu2.1) over (2:3.3.17-6ubuntu2) ...
Preparing to unpack .../02-openssl_3.0.2-0ubuntu1.15_amd64.deb ...
Unpacking openssl (3.0.2-0ubuntu1.15) over (3.0.2-0ubuntu1.12) ...
Preparing to unpack .../03-libctf0_2.38-4ubuntu2.6_amd64.deb ...
Unpacking libctf0:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../04-libctf-nobfd0_2.38-4ubuntu2.6_amd64.deb ...
Unpacking libctf-nobfd0:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../05-binutils-x86-64-linux-gnu_2.38-4ubuntu2.6_amd64.deb ...
Unpacking binutils-x86-64-linux-gnu (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../06-libbinutils_2.38-4ubuntu2.6_amd64.deb ...
Unpacking libbinutils:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../07-binutils_2.38-4ubuntu2.6_amd64.deb ...
Unpacking binutils (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../08-binutils-common_2.38-4ubuntu2.6_amd64.deb ...
Unpacking binutils-common:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../09-cuda-compat-12-2_535.161.08-1_amd64.deb ...
Unpacking cuda-compat-12-2 (535.161.08-1) over (535.129.03-1) ...
Preparing to unpack .../10-cuda-keyring_1.1-1_all.deb ...
Unpacking cuda-keyring (1.1-1) over (1.0-1) ...
Preparing to unpack .../11-cuda-toolkit-12-config-common_12.4.99-1_all.deb ...
Unpacking cuda-toolkit-12-config-common (12.4.99-1) over (12.3.52-1) ...
Preparing to unpack .../12-cuda-toolkit-config-common_12.4.99-1_all.deb ...
Unpacking cuda-toolkit-config-common (12.4.99-1) over (12.3.52-1) ...
Preparing to unpack .../13-dpkg-dev_1.21.1ubuntu2.3_all.deb ...
Unpacking dpkg-dev (1.21.1ubuntu2.3) over (1.21.1ubuntu2.2) ...
Preparing to unpack .../14-libdpkg-perl_1.21.1ubuntu2.3_all.deb ...
Unpacking libdpkg-perl (1.21.1ubuntu2.3) over (1.21.1ubuntu2.2) ...
Preparing to unpack .../15-libldap-2.5-0_2.5.17+dfsg-0ubuntu0.22.04.1_amd64.deb ...
Unpacking libldap-2.5-0:amd64 (2.5.17+dfsg-0ubuntu0.22.04.1) over (2.5.16+dfsg-0ubuntu0.22.04.1) ...
Preparing to unpack .../16-linux-libc-dev_5.15.0-101.111_amd64.deb ...
Unpacking linux-libc-dev:amd64 (5.15.0-101.111) over (5.15.0-88.98) ...
Setting up cuda-toolkit-config-common (12.4.99-1) ...
Setting up cuda-compat-12-2 (535.161.08-1) ...
Setting up binutils-common:amd64 (2.38-4ubuntu2.6) ...
Setting up linux-libc-dev:amd64 (5.15.0-101.111) ...
Setting up libctf-nobfd0:amd64 (2.38-4ubuntu2.6) ...
Setting up perl-modules-5.34 (5.34.0-3ubuntu1.3) ...
Setting up libldap-2.5-0:amd64 (2.5.17+dfsg-0ubuntu0.22.04.1) ...
Setting up cuda-keyring (1.1-1) ...
Setting up libbinutils:amd64 (2.38-4ubuntu2.6) ...
Setting up openssl (3.0.2-0ubuntu1.15) ...
Setting up cuda-toolkit-12-config-common (12.4.99-1) ...
Setting up libprocps8:amd64 (2:3.3.17-6ubuntu2.1) ...
Setting up libctf0:amd64 (2.38-4ubuntu2.6) ...
Setting up libperl5.34:amd64 (5.34.0-3ubuntu1.3) ...
Setting up perl (5.34.0-3ubuntu1.3) ...
Setting up libdpkg-perl (1.21.1ubuntu2.3) ...
Setting up procps (2:3.3.17-6ubuntu2.1) ...
Setting up binutils-x86-64-linux-gnu (2.38-4ubuntu2.6) ...
Setting up binutils (2.38-4ubuntu2.6) ...
Setting up dpkg-dev (1.21.1ubuntu2.3) ...
Processing triggers for man-db (2.10.2-1) ...
Processing triggers for libc-bin (2.35-0ubuntu3.6) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

The following additional packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common
The following NEW packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common xvfb
0 upgraded, 9 newly installed, 0 to remove and 4 not upgraded.
Need to get 7,814 kB of archives.
After this operation, 11.9 MB of additional disk space will be used.
Selecting previously unselected package libfontenc1:amd64.
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../0-libfontenc1_1%3a1.1.4-1build3_amd64.deb ...
Unpacking libfontenc1:amd64 (1:1.1.4-1build3) ...
Selecting previously unselected package libxfont2:amd64.
Preparing to unpack .../1-libxfont2_1%3a2.0.5-1build1_amd64.deb ...
Unpacking libxfont2:amd64 (1:2.0.5-1build1) ...
Selecting previously unselected package libxkbfile1:amd64.
Preparing to unpack .../2-libxkbfile1_1%3a1.1.0-1build3_amd64.deb ...
Unpacking libxkbfile1:amd64 (1:1.1.0-1build3) ...
Selecting previously unselected package x11-xkb-utils.
Preparing to unpack .../3-x11-xkb-utils_7.7+5build4_amd64.deb ...
Unpacking x11-xkb-utils (7.7+5build4) ...
Selecting previously unselected package xfonts-encodings.
Preparing to unpack .../4-xfonts-encodings_1%3a1.0.5-0ubuntu2_all.deb ...
Unpacking xfonts-encodings (1:1.0.5-0ubuntu2) ...
Selecting previously unselected package xfonts-utils.
Preparing to unpack .../5-xfonts-utils_1%3a7.7+6build2_amd64.deb ...
Unpacking xfonts-utils (1:7.7+6build2) ...
Selecting previously unselected package xfonts-base.
Preparing to unpack .../6-xfonts-base_1%3a1.0.5_all.deb ...
Unpacking xfonts-base (1:1.0.5) ...
Selecting previously unselected package xserver-common.
Preparing to unpack .../7-xserver-common_2%3a21.1.4-2ubuntu1.7~22.04.8_all.deb ...
Unpacking xserver-common (2:21.1.4-2ubuntu1.7~22.04.8) ...
Selecting previously unselected package xvfb.
Preparing to unpack .../8-xvfb_2%3a21.1.4-2ubuntu1.7~22.04.8_amd64.deb ...
Unpacking xvfb (2:21.1.4-2ubuntu1.7~22.04.8) ...
Setting up libfontenc1:amd64 (1:1.1.4-1build3) ...
Setting up xfonts-encodings (1:1.0.5-0ubuntu2) ...
Setting up libxkbfile1:amd64 (1:1.1.0-1build3) ...
Setting up libxfont2:amd64 (1:2.0.5-1build1) ...
Setting up x11-xkb-utils (7.7+5build4) ...
Setting up xfonts-utils (1:7.7+6build2) ...
Setting up xfonts-base (1:1.0.5) ...
Setting up xserver-common (2:21.1.4-2ubuntu1.7~22.04.8) ...
Setting up xvfb (2:21.1.4-2ubuntu1.7~22.04.8) ...
Processing triggers for man-db (2.10.2-1) ...
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
Processing triggers for libc-bin (2.35-0ubuntu3.6) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

In [3]:
# Show multiple images as animation
# Colab compatible
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from IPython import display
import os

import gym
def display_frames_as_anim(frames, filepath=None):
    """
    Displays a list of frames as a gif, with controls
    """
    H, W, _ = frames[0].shape
    fig, ax = plt.subplots(1, 1, figsize=(W/100.0, H/100.0))
    ax.axis('off')
    patch = plt.imshow(frames[0])

    def animate(i):
        display.clear_output(wait=True)
        patch.set_data(frames[i])
        return patch

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50, repeat=False)

    if not filepath is None:
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        anim.save(filepath)

    if is_colab:
        display.display(display.HTML(anim.to_jshtml()))
        #plt.close()
    else:
        plt.show()

2. OpenAI Gym : CartPole

In [4]:
import gym

ENV = 'CartPole-v1'
/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)
In [5]:
# ランダムに行動する
np.random.seed(12345)

env = gym.make(ENV, new_step_api=True, render_mode='rgb_array') # new_step_api, render_mode
observation = env.reset()

frames = [ env.render()[0] ]
states = [ observation ]
for step in range(100):
    action = env.action_space.sample()  ## np.random.choice(2) # 0: left, 1: right  ##
    obserbation, reward, done, truncated, info = env.step(action) # when new_step_api==True, two bool (done, truncated) returned
    frames.append(env.render()[0])  # env.render(mode='rgb_array') -> env.render()[0]
    states.append(observation)

    if done:
        break

print(len(states))
print(states)
22
[array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32), array([-0.02962237, -0.02631105, -0.02427756,  0.02858146], dtype=float32)]
/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:241: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
  if not isinstance(terminated, (bool, np.bool8)):
In [6]:
# アニメーション表示する
%matplotlib notebook
import matplotlib.pyplot as plt

#display_frames_as_anim(frames)   # display now
display_frames_as_anim(frames, 'cartpole_video4a.mp4')   # save to file

if is_colab: # copy to google drive
    ! mkdir -p {SAVE_PREFIX}
    ! cp cartpole_video4a.mp4 {SAVE_PREFIX}     # copy to the Google Drive

実行の様子を動画で表示する(HTML)

3. 深層強化学習

ディープラーニングを使わない従来の強化学習として3通りの手法があった。

  • Q学習
  • Sarsa
  • 方策反復法

このうち、Q学習における行動価値関数 $Q(s,a)$ をディープ・ニューラルネットワークで実現する DQN が提案された。

Prioritized Experience Replay は、DQN や DDQN における "Experience Replay" を工夫した手法である。学習に使用する transition をランダムに選択するのではなく、優先順位によって選択する。 教師信号との差である $| R_{t+1} + \gamma \max_{a} Q_t(S_{t+1}, a) - Q(S_t, a_t) |$ が大きい場合は、その行動価値関数 $Q(S_t, a_t)$ に対して学習が進んでいないことになるので、優先的に学習する。

In [7]:
# CartPole で観測した状態変数の値を名前をつけて保存するために namedtuple を使う
from collections import namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
In [8]:
# 定数

GAMMA = 0.99
MAX_STEPS = 200
NUM_EPISODES = 500
In [9]:
# Replay Memory
import random

class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.index = 0

    def push(self, state, action, state_next, reward):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.index] = Transition(state, action, state_next, reward)
        self.index = (self.index + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

6.4 Prioritized Experience Replay

Prioritized Experience Replay の概要

TD 誤差を記録するためのクラス TDerrorMemory を実装する。 基本的には ReplayMemory クラスと同様だが、関数 get_prioritized_indexes と関数 update_td_error を用意する。

  • 関数 get_prioritized_indexes() ... メモリに格納されている TD 誤差の大きさに応じて確率的に index を求める。 ただし、TD 誤差の絶対値を求める際に微小値 TD_ERROR_EPSILON を加算している。
  • 関数 update_td_error() ... メモリに格納されている TD 誤差を更新する。
In [10]:
# p.163

TD_ERROR_EPSILON = 0.0001   # error to add to the bias

class TDerrorMemory:

    def __init__(self, CAPACITY):
        self.capacity = CAPACITY
        self.memory = []
        self.index = 0

    def push(self, td_error):
        '''
        TD 誤差をメモリに保存する
        '''
        if len(self.memory) < self.capacity:
            self.memory.append(None)

        self.memory[self.index] = td_error
        self.index = (self.index + 1) % self.capacity

    def __len__(self):
        return len(self.memory)

    def get_prioritized_indexes(self, batch_size):
        '''
        TD 誤差に応じた確率で index を取得する
        '''

        # TD 誤差の和を計算する
        sum_absolute_td_error = np.sum(np.absolute(self.memory))
        sum_absolute_td_error += TD_ERROR_EPSILON * len(self.memory)

        # batch_size 分の乱数を生成して、昇順に並べる
        rand_list = np.random.uniform(0, sum_absolute_td_error, batch_size)
        rand_list = np.sort(rand_list)

        # 作成した乱数で串刺しにして、インデックスを求める
        indexes = []
        idx = 0
        tmp_sum_absolute_td_error = 0
        for rand_num in rand_list:
            while tmp_sum_absolute_td_error < rand_num:
                tmp_sum_absolute_td_error += abs(self.memory[idx]) + TD_ERROR_EPSILON
                idx += 1

            # 微小値を計算に使用したため、index がメモリの長さを超えた場合の補正
            if idx >= len(self.memory):
                idx = len(self.memory) - 1
            indexes.append(idx)

        return indexes
In [11]:
# p.159
import torch

class Net(torch.nn.Module):

    def __init__(self, n_in, n_mid, n_out):
        super().__init__()
        self.fc1 = torch.nn.Linear(n_in, n_mid)
        self.fc2 = torch.nn.Linear(n_mid, n_mid)

        ##################
        # Dueling Network
        ##################
        self.fc3_adv = torch.nn.Linear(n_mid, n_out)  # Dueling Network
        self.fc3_v = torch.nn.Linear(n_mid, 1)       # 価値 V

    def forward(self, x):
        h1 = torch.nn.functional.relu(self.fc1(x))
        h2 = torch.nn.functional.relu(self.fc2(h1))

        ################
        # Dueling Network (どちらも ReLU しないことに注意)
        ################
        adv = self.fc3_adv(h2)                        # shape (minibatch, 2) , adv.size(2)==2
        val = self.fc3_v(h2).expand(-1, adv.size(1))  # shapeを変換しておく (minibatch, 1)-->(minibatch, 2)

        # val+adv から advの平均を引く
        output = val + adv - adv.mean(dim=1, keepdim=True).expand(-1, adv.size(1))

        return output    # shape (minibatch, 2)

関数 replay() を Prioritized Experience Replay に変更するが、 学習の初期段階(episolde < 30)ではまだ学習が進んでいないので従来通り乱数で選択する。 このため、引数に変数 episode を追加している。

関数 make_minibatch() にも引数に変数 episode を追加している。

関数 update_td_error_memory() では、メモリオブジェクトに保存された全 trainsition の TD 誤差を再計算する。 PyTorch で計算した結果は Tensor 型なので、一旦 Numpy のデータ型にしてから、Python の list 型に変換する。

In [12]:
# p.164

BATCH_SIZE = 32
CAPACITY = 10000

class Brain:

    def __init__(self, num_states, num_actions):
        self.num_actions = num_actions
        self.memory = ReplayMemory(CAPACITY)   # 経験を記憶するメモリ

        n_in, n_mid, n_out = num_states, 32, num_actions   # main と target という2つのネットワークを生成する
        self.main_q_network = Net(n_in, n_mid, n_out)
        self.target_q_network = Net(n_in, n_mid, n_out)

        self.optimizer = torch.optim.Adam(   # 最適化
            self.main_q_network.parameters(),
            lr = 0.0001
        )

        ##############################
        # Prioritzed Experience Replay
        ##############################
        self.td_error_memory = TDerrorMemory(CAPACITY)  # !!! TD誤差のメモリオブジェクトを生成する
        ### End of change ###

    def replay(self, episode):     # !!! added episode argument
        '''
        Experience Replay で学習する
        '''

        # 1. メモリサイズを確認する
        if len(self.memory) < BATCH_SIZE:
            return

        # 2. mini-batch を作成する
        self.batch, self.state_batch, self.action_batch, self.reward_batch, self.non_final_next_states = self.make_minibatch(episode)  # !!! argument added

        # 3. 教師信号となる Q(s_t, a_t) を求める
        self.expected_state_action_values = self.get_expected_state_action_values()

        # 4. パラメータを更新する
        self.update_main_q_network()

    def decide_action(self, state, episode):
        # ε-greedy 法で徐々に最適行動の選択を増やす
        epsilon = 0.5 * (1 / (episode + 1))

        if epsilon <= np.random.uniform(0, 1):
            self.main_q_network.eval()   # 推論モードに切り替える
            with torch.no_grad():
                # torch.max(dim=1) は最大値とインデックスのtupleが返される
                # インデックスにアクセスするのは [1] or .index
                # .view(1, 1) により [torch.LongTensor(1,1)] にreshapeされる
                action = self.main_q_network(state).max(dim=1).indices.view(1, 1)

        else:
            action = torch.LongTensor(
                [[ random.randrange(self.num_actions) ]]   # 0 or 1 の乱数
            )

        return action

    def make_minibatch(self, episode):   # !!! argument episode added
        '''
        2. mini-batch を作成する
        '''

        # 2.1 メモリから mini-batch 分のデータを取り出す
        ###############################
        # Prioritized Experience Replay
        ###############################
        if episode < 30:
            transitions = self.memory.sample(BATCH_SIZE)
        else:
            # TD 誤差に応じてミニバッチを取り出す
            indexes = self.td_error_memory.get_prioritized_indexes(BATCH_SIZE)
            transitions = [ self.memory.memory[n] for n in indexes ]
        # transitions = self.memory.sample(BATCH_SIZE)
        ### End of change ###


        # 2.2 各変数を mini-batch に対応する形へ変形する
        # trainsition = [ (state, action, state_next, reward) ...]
        # これを変形して次の形式にする
        # ( [state, ...], [action, ...], [state_next, ...], [reward, ...])
        batch = Transition(*zip(*transitions))

        # 2.3 'Transitions' named tuple から項目を取り出し、
        # concatenate して torch.FloatTensor size(BATCH_SIZE, 1) に変換する
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

        return batch, state_batch, action_batch, reward_batch, non_final_next_states

    def get_expected_state_action_values(self):
        '''
        3. Q(s_t, a_t) を求める
        '''

        # 3.1 推論モードに切り替える
        self.main_q_network.eval()
        self.target_q_network.eval()

        # 3.2 Q(s_t, a_t) を求める。
        # self.model(state_batch) は各action(左 or 右)に対するQ値を全部返すので
        # 返り値は "torch.FloatTensor size (BATCH_SIZE, 2)"
        # 実行した action a_t に対応する Q 値を gather(dim, index) で引っ張り出す。
        self.state_action_values = self.main_q_network(
            self.state_batch
        ).gather(1, self.action_batch)

        # 3.3 次の状態がある index のmax{Q(s_{t+1}, a)} 値を求める。
        # cartpole が done になっておらず、next_state があるかをチェックするインデックスマスクを作成する
        non_final_mask = torch.ByteTensor(
            tuple(map(lambda s: s is not None, self.batch.next_state))
            )
        next_state_values = torch.zeros(BATCH_SIZE)  # initial value is 0
        a_m = torch.zeros(BATCH_SIZE).type(torch.LongTensor)

        # 次の状態での最大の行動 a_m を main_q_network から求める。
        # 出力にアクセスし、max(dim=1)で列方向の最大値の [値, index] を求める
        # (注意) PyTorch で Tensor 配列の最大値を求めるには torch.max(dim=n) を使う
        # dimを指定した場合は値とインデックスのtupleが返されるので
        # [1] or .indices としてそのインデックスの値 (index=1 ) を出力する
        # detach でその値を取り出す
        a_m[non_final_mask] = self.main_q_network(
            self.non_final_next_states
        ).detach().max(dim=1).indices

        # 次の状態のあるものだけフィルターし、size 32 を (32, 1) へ変形する
        a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)

        # 次の状態がある index の、行動 a_m の Q 値を target Q-Network から求める
        # detach() で取り出す。
        # squeeze で (minibatch, 1) を (minibatch,) へ
        next_state_values[non_final_mask] = self.target_q_network(
            self.non_final_next_states
        ).gather(1, a_m_non_final_next_states).detach().squeeze()

        # 3.4 教師となる Q(s_t, a_t) 値を、Q学習の式から求める。
        expected_state_action_values = self.reward_batch + GAMMA * next_state_values

        return expected_state_action_values

    def update_main_q_network(self):
        '''
        4. パラメータの更新
        '''

        # 4.1 訓練モードに切り替える
        self.main_q_network.train()

        # 4.2 損失関数 (smooth_l1_loss は Huberloss)
        # expected_state_action_values は形状が (minibatch,)なので、
        # unsqueezeで形状を (minibatch, 1) へ
        loss = torch.nn.functional.smooth_l1_loss(
            self.state_action_values,
            self.expected_state_action_values.unsqueeze(1)
        )

        # 4.3 パラメータを更新する
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_q_network(self):
        '''
        target_q_network を main_q_network と同じにする
        '''
        self.target_q_network.load_state_dict(self.main_q_network.state_dict())

    ###############################
    # Prioritized Experience Replay
    ###############################
    def update_td_error_memory(self):
        '''
        TD 誤差メモリに格納されている TD 誤差を更新する
        '''
        self.main_q_network.eval()
        self.target_q_network.eval()

        # 全メモリでミニバッチを作成する
        transitions = self.memory.memory
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

        # ネットワークが出力する Q(s_t, a_t)
        state_action_values = self.main_q_network(
            state_batch
        ).gather(1, action_batch)

        # next_state があるかをチェックするインデックスマスク
        non_final_mask = torch.ByteTensor(
            tuple(map(lambda s: s is not None, batch.next_state))
        )

        next_state_values = torch.zeros(len(self.memory))
        a_m = torch.zeros(len(self.memory)).type(torch.LongTensor)

        # 次の状態での最大Q値の行動 a_m を Main Q-Network から求める
        a_m[non_final_mask] = self.main_q_network(
            non_final_next_states
        ).detach().max(dim=1).indices

        # 次の状態があるものだけにフィルターし、shape (32,) --> (32, 1)
        a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)

        #次の状態がある index の、行動 a_m の Q 値を target Q-Network から求める
        next_state_values[non_final_mask] = self.target_q_network(
            non_final_next_states
            ).gather(1, a_m_non_final_next_states).detach().squeeze()

        # TD 誤差を求める
        td_errors = (reward_batch + GAMMA * next_state_values) - state_action_values.squeeze()

        # TD 誤差メモリを更新する
        self.td_error_memory.memory = td_errors.detach().numpy().tolist()
/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)

Agent クラスに、関数 memorize_td_error() と関数 update_td_error_memory() を追加する。

関数 memorize_td_error() はその step での TD 誤差を格納する。

関数 update_td_error_memory() は、各試行の最後に実行され、TD誤差を更新する。

Brain クラスの関数 replay() の引数に episode を追加したので、関数 update_q_function() の引数にも追加する。

In [13]:
# p.170

class Agent:
    def __init__(self, num_states, num_actions):
        self.brain = Brain(num_states, num_actions)

    def update_q_function(self, episode):
        self.brain.replay(episode)

    def get_action(self, state, episode):
        action = self.brain.decide_action(state, episode)
        return action

    def memorize(self, state, action, state_next, reward):
        self.brain.memory.push(state, action, state_next, reward)

    def update_target_q_function(self):
        self.brain.update_target_q_network()

    ###############################
    # Prioritized Experience Replay
    ###############################
    def memorize_td_error(self, td_error):
        self.brain.td_error_memory.push(td_error)

    def update_td_error_memory(self):
        self.brain.update_td_error_memory()

Environment クラスは、関数 run() の内容を 3 箇所変更する。

各ステップでの TD 誤差を TD 誤差メモリに追加する。 ただし、今回は 0 を保存している。 各試行の終了時に TD 誤差メモリの中身を更新するので、そのタイミングで正しい値が格納される。

Q Network の更新関数で、episode 変数を追加している。

各試行の最後で TD 誤差メモリの内容を更新する。

In [14]:
# p.171
# [自分へのメモ] env.step() の返り値の変更、env.render() の引数と返り値の変更に対応した
# run() の中で動画(frames)を表示せず、frames を返すように変更した

class Environment:

    def __init__(self):
        self.env = gym.make(ENV, new_step_api=True, render_mode='rgb_array')
        self.num_states = self.env.observation_space.shape[0]
        self.num_actions = self.env.action_space.n
        self.agent = Agent(self.num_states, self.num_actions)

    def run(self):
        episode_10_list = np.zeros(10)  # 10 試行分の立ち続けた step 数を格納し、平均ステップ数を出力に利用
        complete_episodes = 0   # 195 steps 以上立ち続けた試行数
        episode_final = False   # 最後の試行であるか
        frames = []             # 最後の試行を動画にするための配列

        for episode in range(NUM_EPISODES):
            observation = self.env.reset()

            state = observation   # 観測をそのまま状態 s として使用する
            state = torch.from_numpy(state).type(torch.FloatTensor)
            state = torch.unsqueeze(state, dim=0)  # FloatTensor size(4) -> size(1,4)

            for step in range(MAX_STEPS):

                if episode_final is True:
                    frames.append(self.env.render()[0])

                action = self.agent.get_action(state, episode)

                # 行動 a_t の実行により、s_{t+1} と done フラグを求める。
                # (注) new_api=True なのでtruncated も見るべし。

                observation_next, _, done, truncated, _ = self.env.step(action.item())

                if done or truncated:
                    state_next = None  # 次の状態は無い
                    episode_10_list = np.hstack((episode_10_list[1:], step+1))

                    if step < 195:
                        reward = torch.FloatTensor([-1.0])
                        complete_episodes = 0
                    else:
                        reward = torch.FloatTensor([1.0])
                        complete_episodes = complete_episodes + 1

                elif step == MAX_STEPS - 1:
                    state_next = None
                    episode_10_list = np.hstack((episode_10_list[1:], step+1))
                    reward = torch.FloatTensor([1.0])
                    complete_episodes += 1
                    truncated = True

                else:
                    reward = torch.FloatTensor([0.0])
                    state_next = observation_next
                    state_next = torch.from_numpy(state_next).type(torch.FloatTensor)
                    # FloatTensor size (4) --> (1, 4)
                    state_next = torch.unsqueeze(state_next, dim=0)

                # メモリに経験を追加
                self.agent.memorize(state, action, state_next, reward)

                ###############################
                # Prioritized Experience Replay
                ###############################
                self.agent.memorize_td_error(0)   # 本当はTD誤差だが、一旦0としておく

                ###########################################
                # Prioritized Experience Replay で Q 関数を更新する
                ############################################
                self.agent.update_q_function(episode)

                # 観測の更新
                state = state_next

                # 終了時の処理
                if done or truncated:
                    print(f'{episode} Episode: Finished after {step+1} steps: average steps = {episode_10_list.mean(): .1f}')

                    ################################
                    # Prioritiezed Experience Replay
                    ################################
                    self.agent.update_td_error_memory()

                    #########################
                    # Double-DQN により追加
                    #########################
                    # 2試行に一度 target_q_network を main_q_network と同じにする
                    if (episode % 2 == 0):
                        self.agent.update_target_q_function()

                    break

            if episode_final is True:
                # display_frames_as_gif(frames)
                break

            # 10 連続で 200 steps 立ち続けたら成功
            if complete_episodes >= 10:
                print('10 consecutive success')
                episode_final = True

        return frames
In [15]:
# main クラス
cartpole_env = Environment()
frames = cartpole_env.run()
/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:241: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
  if not isinstance(terminated, (bool, np.bool8)):
0 Episode: Finished after 12 steps: average steps =  1.2
1 Episode: Finished after 16 steps: average steps =  2.8
<ipython-input-12-72237b7ae537>:208: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  a_m[non_final_mask] = self.main_q_network(
<ipython-input-12-72237b7ae537>:213: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)
<ipython-input-12-72237b7ae537>:216: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  next_state_values[non_final_mask] = self.target_q_network(
<ipython-input-12-72237b7ae537>:129: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  a_m[non_final_mask] = self.main_q_network(
<ipython-input-12-72237b7ae537>:134: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)
<ipython-input-12-72237b7ae537>:139: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  next_state_values[non_final_mask] = self.target_q_network(
2 Episode: Finished after 12 steps: average steps =  4.0
3 Episode: Finished after 13 steps: average steps =  5.3
4 Episode: Finished after 12 steps: average steps =  6.5
5 Episode: Finished after 9 steps: average steps =  7.4
6 Episode: Finished after 10 steps: average steps =  8.4
7 Episode: Finished after 12 steps: average steps =  9.6
8 Episode: Finished after 11 steps: average steps =  10.7
9 Episode: Finished after 12 steps: average steps =  11.9
10 Episode: Finished after 11 steps: average steps =  11.8
11 Episode: Finished after 11 steps: average steps =  11.3
12 Episode: Finished after 10 steps: average steps =  11.1
13 Episode: Finished after 13 steps: average steps =  11.1
14 Episode: Finished after 13 steps: average steps =  11.2
15 Episode: Finished after 26 steps: average steps =  12.9
16 Episode: Finished after 36 steps: average steps =  15.5
17 Episode: Finished after 41 steps: average steps =  18.4
18 Episode: Finished after 51 steps: average steps =  22.4
19 Episode: Finished after 189 steps: average steps =  40.1
20 Episode: Finished after 179 steps: average steps =  56.9
21 Episode: Finished after 200 steps: average steps =  75.8
22 Episode: Finished after 98 steps: average steps =  84.6
23 Episode: Finished after 71 steps: average steps =  90.4
24 Episode: Finished after 61 steps: average steps =  95.2
25 Episode: Finished after 69 steps: average steps =  99.5
26 Episode: Finished after 75 steps: average steps =  103.4
27 Episode: Finished after 109 steps: average steps =  110.2
28 Episode: Finished after 200 steps: average steps =  125.1
29 Episode: Finished after 200 steps: average steps =  126.2
30 Episode: Finished after 200 steps: average steps =  128.3
31 Episode: Finished after 200 steps: average steps =  128.3
32 Episode: Finished after 200 steps: average steps =  138.5
33 Episode: Finished after 200 steps: average steps =  151.4
34 Episode: Finished after 200 steps: average steps =  165.3
35 Episode: Finished after 104 steps: average steps =  168.8
36 Episode: Finished after 200 steps: average steps =  181.3
37 Episode: Finished after 200 steps: average steps =  190.4
38 Episode: Finished after 200 steps: average steps =  190.4
39 Episode: Finished after 148 steps: average steps =  185.2
40 Episode: Finished after 108 steps: average steps =  176.0
41 Episode: Finished after 200 steps: average steps =  176.0
42 Episode: Finished after 200 steps: average steps =  176.0
43 Episode: Finished after 150 steps: average steps =  171.0
44 Episode: Finished after 102 steps: average steps =  161.2
45 Episode: Finished after 200 steps: average steps =  170.8
46 Episode: Finished after 200 steps: average steps =  170.8
47 Episode: Finished after 178 steps: average steps =  168.6
48 Episode: Finished after 200 steps: average steps =  168.6
49 Episode: Finished after 200 steps: average steps =  173.8
50 Episode: Finished after 200 steps: average steps =  183.0
51 Episode: Finished after 200 steps: average steps =  183.0
52 Episode: Finished after 200 steps: average steps =  183.0
53 Episode: Finished after 200 steps: average steps =  188.0
54 Episode: Finished after 200 steps: average steps =  197.8
55 Episode: Finished after 200 steps: average steps =  197.8
56 Episode: Finished after 200 steps: average steps =  197.8
57 Episode: Finished after 200 steps: average steps =  200.0
10 consecutive success
58 Episode: Finished after 196 steps: average steps =  199.6
In [16]:
# アニメーション表示する
%matplotlib notebook
import matplotlib.pyplot as plt

#display_frames_as_anim(frames)   # display now
display_frames_as_anim(frames, 'cartpole_video4b.mp4')   # save to file

if is_colab: # copy to google drive
    ! mkdir -p {SAVE_PREFIX}
    ! cp cartpole_video4b.mp4 {SAVE_PREFIX}     # copy to the Google Drive

実行の様子を動画で表示する(HTML)

In [ ]: