Jul/29/2023 by Yoshihisa Nitta
Google Colab 対応コード

In [1]:
# by nitta
import os
is_colab = 'google.colab' in str(get_ipython())   # for Google Colab

if is_colab:
    from google.colab import drive
    drive.mount('/content/drive')
    SAVE_PREFIX='/content/drive/MyDrive/DeepLearning/openai_gym_rt'
else:
    SAVE_PREFIX='.'
Mounted at /content/drive

Reinforced Training (強化学習) と Deep Reinforced Training (深層強化学習)

1. Google Colab

In [2]:
# Check if running on Colab
#is_colab = 'google.colab' in str(get_ipython())   # for Google Colab

# packages
if is_colab:
    !apt update -qq
    !apt upgrade -qq
    !apt install -qq xvfb
    !pip -q install pyvirtualdisplay

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import matplotlib
matplotlib.rcParams['animation.embed_limit'] = 2**32
39 packages can be upgraded. Run 'apt list --upgradable' to see them.
The following packages have been kept back:
  libcudnn8 libcudnn8-dev libnccl-dev libnccl2
The following packages will be upgraded:
  base-files bash binutils binutils-common binutils-x86-64-linux-gnu coreutils cuda-compat-12-2
  cuda-keyring cuda-toolkit-12-config-common cuda-toolkit-config-common dpkg dpkg-dev libbinutils
  libc-bin libctf-nobfd0 libctf0 libdpkg-perl libgnutls30 libldap-2.5-0 libpam-modules
  libpam-modules-bin libpam-runtime libpam0g libperl5.34 libprocps8 libudev1 linux-libc-dev login
  openssl passwd perl perl-base perl-modules-5.34 procps tar
35 upgraded, 0 newly installed, 0 to remove and 4 not upgraded.
Need to get 58.1 MB of archives.
After this operation, 63.5 kB of additional disk space will be used.
Extracting templates from packages: 100%
Preconfiguring packages ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../base-files_12ubuntu4.6_amd64.deb ...
Unpacking base-files (12ubuntu4.6) over (12ubuntu4.4) ...
Setting up base-files (12ubuntu4.6) ...
Installing new version of config file /etc/issue ...
Installing new version of config file /etc/issue.net ...
Installing new version of config file /etc/lsb-release ...
Installing new version of config file /etc/update-motd.d/10-help-text ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../bash_5.1-6ubuntu1.1_amd64.deb ...
Unpacking bash (5.1-6ubuntu1.1) over (5.1-6ubuntu1) ...
Setting up bash (5.1-6ubuntu1.1) ...
update-alternatives: using /usr/share/man/man7/bash-builtins.7.gz to provide /usr/share/man/man7/builtins.7.gz (builtins.7.gz) in auto mode
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../coreutils_8.32-4.1ubuntu1.1_amd64.deb ...
Unpacking coreutils (8.32-4.1ubuntu1.1) over (8.32-4.1ubuntu1) ...
Setting up coreutils (8.32-4.1ubuntu1.1) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../tar_1.34+dfsg-1ubuntu0.1.22.04.2_amd64.deb ...
Unpacking tar (1.34+dfsg-1ubuntu0.1.22.04.2) over (1.34+dfsg-1ubuntu0.1.22.04.1) ...
Setting up tar (1.34+dfsg-1ubuntu0.1.22.04.2) ...
update-alternatives: warning: forcing reinstallation of alternative /usr/sbin/rmt-tar because link group rmt is broken
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../dpkg_1.21.1ubuntu2.3_amd64.deb ...
Unpacking dpkg (1.21.1ubuntu2.3) over (1.21.1ubuntu2.2) ...
Setting up dpkg (1.21.1ubuntu2.3) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../login_1%3a4.8.1-2ubuntu2.2_amd64.deb ...
Unpacking login (1:4.8.1-2ubuntu2.2) over (1:4.8.1-2ubuntu2.1) ...
Setting up login (1:4.8.1-2ubuntu2.2) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libperl5.34_5.34.0-3ubuntu1.3_amd64.deb ...
Unpacking libperl5.34:amd64 (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Preparing to unpack .../perl_5.34.0-3ubuntu1.3_amd64.deb ...
Unpacking perl (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Preparing to unpack .../perl-base_5.34.0-3ubuntu1.3_amd64.deb ...
Unpacking perl-base (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Setting up perl-base (5.34.0-3ubuntu1.3) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../perl-modules-5.34_5.34.0-3ubuntu1.3_all.deb ...
Unpacking perl-modules-5.34 (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Preparing to unpack .../libc-bin_2.35-0ubuntu3.6_amd64.deb ...
Unpacking libc-bin (2.35-0ubuntu3.6) over (2.35-0ubuntu3.4) ...
Setting up libc-bin (2.35-0ubuntu3.6) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam0g_1.4.0-11ubuntu2.4_amd64.deb ...
Unpacking libpam0g:amd64 (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam0g:amd64 (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam-modules-bin_1.4.0-11ubuntu2.4_amd64.deb ...
Unpacking libpam-modules-bin (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam-modules-bin (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam-modules_1.4.0-11ubuntu2.4_amd64.deb ...
Unpacking libpam-modules:amd64 (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam-modules:amd64 (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam-runtime_1.4.0-11ubuntu2.4_all.deb ...
Unpacking libpam-runtime (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam-runtime (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libudev1_249.11-0ubuntu3.12_amd64.deb ...
Unpacking libudev1:amd64 (249.11-0ubuntu3.12) over (249.11-0ubuntu3.10) ...
Setting up libudev1:amd64 (249.11-0ubuntu3.12) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../passwd_1%3a4.8.1-2ubuntu2.2_amd64.deb ...
Unpacking passwd (1:4.8.1-2ubuntu2.2) over (1:4.8.1-2ubuntu2.1) ...
Setting up passwd (1:4.8.1-2ubuntu2.2) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libgnutls30_3.7.3-4ubuntu1.4_amd64.deb ...
Unpacking libgnutls30:amd64 (3.7.3-4ubuntu1.4) over (3.7.3-4ubuntu1.2) ...
Setting up libgnutls30:amd64 (3.7.3-4ubuntu1.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../00-libprocps8_2%3a3.3.17-6ubuntu2.1_amd64.deb ...
Unpacking libprocps8:amd64 (2:3.3.17-6ubuntu2.1) over (2:3.3.17-6ubuntu2) ...
Preparing to unpack .../01-procps_2%3a3.3.17-6ubuntu2.1_amd64.deb ...
Unpacking procps (2:3.3.17-6ubuntu2.1) over (2:3.3.17-6ubuntu2) ...
Preparing to unpack .../02-openssl_3.0.2-0ubuntu1.15_amd64.deb ...
Unpacking openssl (3.0.2-0ubuntu1.15) over (3.0.2-0ubuntu1.12) ...
Preparing to unpack .../03-libctf0_2.38-4ubuntu2.6_amd64.deb ...
Unpacking libctf0:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../04-libctf-nobfd0_2.38-4ubuntu2.6_amd64.deb ...
Unpacking libctf-nobfd0:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../05-binutils-x86-64-linux-gnu_2.38-4ubuntu2.6_amd64.deb ...
Unpacking binutils-x86-64-linux-gnu (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../06-libbinutils_2.38-4ubuntu2.6_amd64.deb ...
Unpacking libbinutils:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../07-binutils_2.38-4ubuntu2.6_amd64.deb ...
Unpacking binutils (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../08-binutils-common_2.38-4ubuntu2.6_amd64.deb ...
Unpacking binutils-common:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../09-cuda-compat-12-2_535.161.08-1_amd64.deb ...
Unpacking cuda-compat-12-2 (535.161.08-1) over (535.129.03-1) ...
Preparing to unpack .../10-cuda-keyring_1.1-1_all.deb ...
Unpacking cuda-keyring (1.1-1) over (1.0-1) ...
Preparing to unpack .../11-cuda-toolkit-12-config-common_12.4.99-1_all.deb ...
Unpacking cuda-toolkit-12-config-common (12.4.99-1) over (12.3.52-1) ...
Preparing to unpack .../12-cuda-toolkit-config-common_12.4.99-1_all.deb ...
Unpacking cuda-toolkit-config-common (12.4.99-1) over (12.3.52-1) ...
Preparing to unpack .../13-dpkg-dev_1.21.1ubuntu2.3_all.deb ...
Unpacking dpkg-dev (1.21.1ubuntu2.3) over (1.21.1ubuntu2.2) ...
Preparing to unpack .../14-libdpkg-perl_1.21.1ubuntu2.3_all.deb ...
Unpacking libdpkg-perl (1.21.1ubuntu2.3) over (1.21.1ubuntu2.2) ...
Preparing to unpack .../15-libldap-2.5-0_2.5.17+dfsg-0ubuntu0.22.04.1_amd64.deb ...
Unpacking libldap-2.5-0:amd64 (2.5.17+dfsg-0ubuntu0.22.04.1) over (2.5.16+dfsg-0ubuntu0.22.04.1) ...
Preparing to unpack .../16-linux-libc-dev_5.15.0-101.111_amd64.deb ...
Unpacking linux-libc-dev:amd64 (5.15.0-101.111) over (5.15.0-88.98) ...
Setting up cuda-toolkit-config-common (12.4.99-1) ...
Setting up cuda-compat-12-2 (535.161.08-1) ...
Setting up binutils-common:amd64 (2.38-4ubuntu2.6) ...
Setting up linux-libc-dev:amd64 (5.15.0-101.111) ...
Setting up libctf-nobfd0:amd64 (2.38-4ubuntu2.6) ...
Setting up perl-modules-5.34 (5.34.0-3ubuntu1.3) ...
Setting up libldap-2.5-0:amd64 (2.5.17+dfsg-0ubuntu0.22.04.1) ...
Setting up cuda-keyring (1.1-1) ...
Setting up libbinutils:amd64 (2.38-4ubuntu2.6) ...
Setting up openssl (3.0.2-0ubuntu1.15) ...
Setting up cuda-toolkit-12-config-common (12.4.99-1) ...
Setting up libprocps8:amd64 (2:3.3.17-6ubuntu2.1) ...
Setting up libctf0:amd64 (2.38-4ubuntu2.6) ...
Setting up libperl5.34:amd64 (5.34.0-3ubuntu1.3) ...
Setting up perl (5.34.0-3ubuntu1.3) ...
Setting up libdpkg-perl (1.21.1ubuntu2.3) ...
Setting up procps (2:3.3.17-6ubuntu2.1) ...
Setting up binutils-x86-64-linux-gnu (2.38-4ubuntu2.6) ...
Setting up binutils (2.38-4ubuntu2.6) ...
Setting up dpkg-dev (1.21.1ubuntu2.3) ...
Processing triggers for man-db (2.10.2-1) ...
Processing triggers for libc-bin (2.35-0ubuntu3.6) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

The following additional packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common
The following NEW packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common xvfb
0 upgraded, 9 newly installed, 0 to remove and 4 not upgraded.
Need to get 7,814 kB of archives.
After this operation, 11.9 MB of additional disk space will be used.
Selecting previously unselected package libfontenc1:amd64.
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../0-libfontenc1_1%3a1.1.4-1build3_amd64.deb ...
Unpacking libfontenc1:amd64 (1:1.1.4-1build3) ...
Selecting previously unselected package libxfont2:amd64.
Preparing to unpack .../1-libxfont2_1%3a2.0.5-1build1_amd64.deb ...
Unpacking libxfont2:amd64 (1:2.0.5-1build1) ...
Selecting previously unselected package libxkbfile1:amd64.
Preparing to unpack .../2-libxkbfile1_1%3a1.1.0-1build3_amd64.deb ...
Unpacking libxkbfile1:amd64 (1:1.1.0-1build3) ...
Selecting previously unselected package x11-xkb-utils.
Preparing to unpack .../3-x11-xkb-utils_7.7+5build4_amd64.deb ...
Unpacking x11-xkb-utils (7.7+5build4) ...
Selecting previously unselected package xfonts-encodings.
Preparing to unpack .../4-xfonts-encodings_1%3a1.0.5-0ubuntu2_all.deb ...
Unpacking xfonts-encodings (1:1.0.5-0ubuntu2) ...
Selecting previously unselected package xfonts-utils.
Preparing to unpack .../5-xfonts-utils_1%3a7.7+6build2_amd64.deb ...
Unpacking xfonts-utils (1:7.7+6build2) ...
Selecting previously unselected package xfonts-base.
Preparing to unpack .../6-xfonts-base_1%3a1.0.5_all.deb ...
Unpacking xfonts-base (1:1.0.5) ...
Selecting previously unselected package xserver-common.
Preparing to unpack .../7-xserver-common_2%3a21.1.4-2ubuntu1.7~22.04.8_all.deb ...
Unpacking xserver-common (2:21.1.4-2ubuntu1.7~22.04.8) ...
Selecting previously unselected package xvfb.
Preparing to unpack .../8-xvfb_2%3a21.1.4-2ubuntu1.7~22.04.8_amd64.deb ...
Unpacking xvfb (2:21.1.4-2ubuntu1.7~22.04.8) ...
Setting up libfontenc1:amd64 (1:1.1.4-1build3) ...
Setting up xfonts-encodings (1:1.0.5-0ubuntu2) ...
Setting up libxkbfile1:amd64 (1:1.1.0-1build3) ...
Setting up libxfont2:amd64 (1:2.0.5-1build1) ...
Setting up x11-xkb-utils (7.7+5build4) ...
Setting up xfonts-utils (1:7.7+6build2) ...
Setting up xfonts-base (1:1.0.5) ...
Setting up xserver-common (2:21.1.4-2ubuntu1.7~22.04.8) ...
Setting up xvfb (2:21.1.4-2ubuntu1.7~22.04.8) ...
Processing triggers for man-db (2.10.2-1) ...
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
Processing triggers for libc-bin (2.35-0ubuntu3.6) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

In [3]:
# Show multiple images as animation
# Colab compatible
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from IPython import display
import os

import gym
def display_frames_as_anim(frames, filepath=None):
    """
    Displays a list of frames as a gif, with controls
    """
    H, W, _ = frames[0].shape
    fig, ax = plt.subplots(1, 1, figsize=(W/100.0, H/100.0))
    ax.axis('off')
    patch = plt.imshow(frames[0])

    def animate(i):
        display.clear_output(wait=True)
        patch.set_data(frames[i])
        return patch

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50, repeat=False)

    if not filepath is None:
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        anim.save(filepath)

    if is_colab:
        display.display(display.HTML(anim.to_jshtml()))
        #plt.close()
    else:
        plt.show()

2. OpenAI Gym : CartPole

In [4]:
import gym

ENV = 'CartPole-v1'
/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)
In [5]:
# ランダムに行動する
np.random.seed(12345)

env = gym.make(ENV, new_step_api=True, render_mode='rgb_array') # new_step_api, render_mode
observation = env.reset()

frames = [ env.render()[0] ]
states = [ observation ]
for step in range(100):
    action = env.action_space.sample()  ## np.random.choice(2) # 0: left, 1: right  ##
    obserbation, reward, done, truncated, info = env.step(action) # when new_step_api==True, two bool (done, truncated) returned
    frames.append(env.render()[0])  # env.render(mode='rgb_array') -> env.render()[0]
    states.append(observation)

    if done:
        break

print(len(states))
print(states)
22
[array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32), array([-0.01051384, -0.01437777,  0.0008383 ,  0.0220046 ], dtype=float32)]
/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:241: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
  if not isinstance(terminated, (bool, np.bool8)):
In [6]:
# アニメーション表示する
%matplotlib notebook
import matplotlib.pyplot as plt

#display_frames_as_anim(frames)   # display now
display_frames_as_anim(frames, 'cartpole_video1b.mp4')   # save to file

if is_colab: # copy to google drive
    ! mkdir -p {SAVE_PREFIX}
    ! cp cartpole_video2a.mp4 {SAVE_PREFIX}     # copy to the Google Drive
cp: cannot stat 'cartpole_video2a.mp4': No such file or directory

実行の様子を動画で表示する(HTML)

3. 深層強化学習

ディープラーニングを使わない従来の強化学習として3通りの手法があった。

  • Q学習
  • Sarsa
  • 方策反復法

このうち、Q学習における行動価値関数 $Q(s,a)$ をディープ・ニューラルネットワークで実現する DQN が提案された。

In [7]:
# CartPole で観測した状態変数の値を名前をつけて保存するために namedtuple を使う
from collections import namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)
In [8]:
# 定数

GAMMA = 0.99
MAX_STEPS = 200
NUM_EPISODES = 500
In [9]:
# p.130
import random

class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.index = 0

    def push(self, state, action, state_next, reward):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.index] = Transition(state, action, state_next, reward)
        self.index = (self.index + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

6.2 DDQN (Double-DQN)

6.2.1. DDQN の概要

Q学習における行動価値関数 $Q(s,a)$ をディープ・ニューラルネットワークで実現する DQN が提案された。

Double-Q学習 (行動価値関数 $Q$ の更新に $Q$ を使う必要があり、学習が安定しなかったのを、2つのネットワークを使うことで改善した) と DQN を組み合わせて DDQN (Double-DQN) が提案された。

6.2.2. DDQN の実装

In [10]:
# p.149
import torch

class Net(torch.nn.Module):

    def __init__(self, n_in, n_mid, n_out):
        super().__init__()
        self.fc1 = torch.nn.Linear(n_in, n_mid)
        self.fc2 = torch.nn.Linear(n_mid, n_mid)
        self.fc3 = torch.nn.Linear(n_mid, n_out)

    def forward(self, x):
        h1 = torch.nn.functional.relu(self.fc1(x))
        h2 = torch.nn.functional.relu(self.fc2(h1))
        output = self.fc3(h2)
        return output
In [11]:
# Double DQN の実装

BATCH_SIZE = 32
CAPACITY = 10000

class Brain:

    def __init__(self, num_states, num_actions):
        self.num_actions = num_actions
        self.memory = ReplayMemory(CAPACITY)   # 経験を記憶するメモリ

        n_in, n_mid, n_out = num_states, 32, num_actions   # main と target という2つのネットワークを生成する
        self.main_q_network = Net(n_in, n_mid, n_out)
        self.target_q_network = Net(n_in, n_mid, n_out)

        self.optimizer = torch.optim.Adam(   # 最適化
            self.main_q_network.parameters(),
            lr = 0.0001
        )

    def replay(self):
        '''
        Experience Replay で学習する
        '''

        # 1. メモリサイズを確認する
        if len(self.memory) < BATCH_SIZE:
            return

        # 2. mini-batch を作成する
        self.batch, self.state_batch, self.action_batch, self.reward_batch, self.non_final_next_states = self.make_minibatch()

        # 3. 教師信号となる Q(s_t, a_t) を求める
        self.expected_state_action_values = self.get_expected_state_action_values()

        # 4. パラメータを更新する
        self.update_main_q_network()

    def decide_action(self, state, episode):
        # ε-greedy 法で徐々に最適行動の選択を増やす
        epsilon = 0.5 * (1 / (episode + 1))

        if epsilon <= np.random.uniform(0, 1):
            self.main_q_network.eval()   # 推論モードに切り替える
            with torch.no_grad():
                # torch.max(dim=1) は最大値とインデックスのtupleが返される
                # インデックスにアクセスするのは [1] or .index
                # .view(1, 1) により [torch.LongTensor(1,1)] にreshapeされる
                action = self.main_q_network(state).max(dim=1).indices.view(1, 1)

        else:
            action = torch.LongTensor(
                [[ random.randrange(self.num_actions) ]]   # 0 or 1 の乱数
            )

        return action

    def make_minibatch(self):
        '''
        2. mini-batch を作成する
        '''

        # 2.1 メモリから mini-batch 分のデータを取り出す
        transitions = self.memory.sample(BATCH_SIZE)

        # 2.2 各変数を mini-batch に対応する形へ変形する
        # trainsition = [ (state, action, state_next, reward) ...]
        # これを変形して次の形式にする
        # ( [state, ...], [action, ...], [state_next, ...], [reward, ...])
        batch = Transition(*zip(*transitions))

        # 2.3 'Transitions' named tuple から項目を取り出し、
        # concatenate して torch.FloatTensor size(BATCH_SIZE, 1) に変換する
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

        return batch, state_batch, action_batch, reward_batch, non_final_next_states

    def get_expected_state_action_values(self):
        '''
        3. Q(s_t, a_t) を求める
        '''

        # 3.1 推論モードに切り替える
        self.main_q_network.eval()
        self.target_q_network.eval()

        # 3.2 Q(s_t, a_t) を求める。
        # self.model(state_batch) は各action(左 or 右)に対するQ値を全部返すので
        # 返り値は "torch.FloatTensor size (BATCH_SIZE, 2)"
        # 実行した action a_t に対応する Q 値を gather(dim, index) で引っ張り出す。
        self.state_action_values = self.main_q_network(
            self.state_batch
        ).gather(1, self.action_batch)

        # 3.3 次の状態がある index のmax{Q(s_{t+1}, a)} 値を求める。
        # cartpole が done になっておらず、next_state があるかをチェックするインデックスマスクを作成する
        non_final_mask = torch.ByteTensor(
            tuple(map(lambda s: s is not None, self.batch.next_state))
            )
        next_state_values = torch.zeros(BATCH_SIZE)  # initial value is 0
        a_m = torch.zeros(BATCH_SIZE).type(torch.LongTensor)

        # 次の状態での最大の行動 a_m を main_q_network から求める。
        # 出力にアクセスし、max(dim=1)で列方向の最大値の [値, index] を求める
        # (注意) PyTorch で Tensor 配列の最大値を求めるには torch.max(dim=n) を使う
        # dimを指定した場合は値とインデックスのtupleが返されるので
        # [1] or .indices としてそのインデックスの値 (index=1 ) を出力する
        # detach でその値を取り出す
        a_m[non_final_mask] = self.main_q_network(
            self.non_final_next_states
        ).detach().max(dim=1).indices

        # 次の状態のあるものだけフィルターし、size 32 を (32, 1) へ変形する
        a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)

        # 次の状態がある index の、行動 a_m の Q 値を target Q-Network から求める
        # detach() で取り出す。
        # squeeze で (minibatch, 1) を (minibatch,) へ
        next_state_values[non_final_mask] = self.target_q_network(
            self.non_final_next_states
        ).gather(1, a_m_non_final_next_states).detach().squeeze()

        # 3.4 教師となる Q(s_t, a_t) 値を、Q学習の式から求める。
        expected_state_action_values = self.reward_batch + GAMMA * next_state_values

        return expected_state_action_values

    def update_main_q_network(self):
        '''
        4. パラメータの更新
        '''

        # 4.1 訓練モードに切り替える
        self.main_q_network.train()

        # 4.2 損失関数 (smooth_l1_loss は Huberloss)
        # expected_state_action_values は形状が (minibatch,)なので、
        # unsqueezeで形状を (minibatch, 1) へ
        loss = torch.nn.functional.smooth_l1_loss(
            self.state_action_values,
            self.expected_state_action_values.unsqueeze(1)
        )

        # 4.3 パラメータを更新する
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_q_network(self):
        '''
        target_q_network を main_q_network と同じにする
        '''
        self.target_q_network.load_state_dict(self.main_q_network.state_dict())
/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)
In [12]:
class Agent:
    def __init__(self, num_states, num_actions):
        self.brain = Brain(num_states, num_actions)

    def update_q_function(self):
        self.brain.replay()

    def get_action(self, state, episode):
        action = self.brain.decide_action(state, episode)
        return action

    def memorize(self, state, action, state_next, reward):
        self.brain.memory.push(state, action, state_next, reward)

    ##### ADDED
    def update_target_q_function(self):
        self.brain.update_target_q_network()
In [13]:
# [自分へのメモ] env.step() の返り値の変更、env.render() の引数と返り値の変更に対応した
# run() の中で動画(frames)を表示せず、frames を返すように変更した

class Environment:

    def __init__(self):
        self.env = gym.make(ENV, new_step_api=True, render_mode='rgb_array')
        self.num_states = self.env.observation_space.shape[0]
        self.num_actions = self.env.action_space.n
        self.agent = Agent(self.num_states, self.num_actions)

    def run(self):
        episode_10_list = np.zeros(10)  # 10 試行分の立ち続けた step 数を格納し、平均ステップ数を出力に利用
        complete_episodes = 0   # 195 steps 以上立ち続けた試行数
        episode_final = False   # 最後の試行であるか
        frames = []             # 最後の試行を動画にするための配列

        for episode in range(NUM_EPISODES):
            observation = self.env.reset()

            state = observation   # 観測をそのまま状態 s として使用する
            state = torch.from_numpy(state).type(torch.FloatTensor)
            state = torch.unsqueeze(state, dim=0)  # FloatTensor size(4) -> size(1,4)

            for step in range(MAX_STEPS):

                if episode_final is True:
                    frames.append(self.env.render()[0])

                action = self.agent.get_action(state, episode)

                # 行動 a_t の実行により、s_{t+1} と done フラグを求める。
                # (注) new_api=True なのでtruncated も見るべし。

                observation_next, _, done, truncated, _ = self.env.step(action.item())

                if done or truncated:
                    state_next = None  # 次の状態は無い
                    episode_10_list = np.hstack((episode_10_list[1:], step+1))

                    if step < 195:
                        reward = torch.FloatTensor([-1.0])
                        complete_episodes = 0
                    else:
                        reward = torch.FloatTensor([1.0])
                        complete_episodes = complete_episodes + 1

                elif step == MAX_STEPS - 1:
                    state_next = None
                    episode_10_list = np.hstack((episode_10_list[1:], step+1))
                    reward = torch.FloatTensor([1.0])
                    complete_episodes += 1
                    truncated = True

                else:
                    reward = torch.FloatTensor([0.0])
                    state_next = observation_next
                    state_next = torch.from_numpy(state_next).type(torch.FloatTensor)
                    # FloatTensor size (4) --> (1, 4)
                    state_next = torch.unsqueeze(state_next, dim=0)

                # メモリに経験を追加
                self.agent.memorize(state, action, state_next, reward)

                # Experience Replay で Q 関数を更新する
                self.agent.update_q_function()

                # 観測の更新
                state = state_next

                # 終了時の処理
                if done or truncated:
                    print(f'{episode} Episode: Finished after {step+1} steps: average steps = {episode_10_list.mean(): .1f}')

                    #########################
                    # Double-DQN により追加
                    #########################
                    # 2試行に一度 target_q_network を main_q_network と同じにする
                    if (episode % 2 == 0):
                        self.agent.update_target_q_function()

                    break

            if episode_final is True:
                # display_frames_as_gif(frames)
                break

            # 10 連続で 200 steps 立ち続けたら成功
            if complete_episodes >= 10:
                print('10 consecutive success')
                episode_final = True

        return frames
In [14]:
# main クラス
cartpole_env = Environment()
frames = cartpole_env.run()
/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:241: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
  if not isinstance(terminated, (bool, np.bool8)):
0 Episode: Finished after 12 steps: average steps =  1.2
1 Episode: Finished after 9 steps: average steps =  2.1
2 Episode: Finished after 9 steps: average steps =  3.0
<ipython-input-11-01bb0eafb206>:112: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  a_m[non_final_mask] = self.main_q_network(
<ipython-input-11-01bb0eafb206>:117: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)
<ipython-input-11-01bb0eafb206>:122: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  next_state_values[non_final_mask] = self.target_q_network(
3 Episode: Finished after 11 steps: average steps =  4.1
4 Episode: Finished after 10 steps: average steps =  5.1
5 Episode: Finished after 9 steps: average steps =  6.0
6 Episode: Finished after 10 steps: average steps =  7.0
7 Episode: Finished after 8 steps: average steps =  7.8
8 Episode: Finished after 9 steps: average steps =  8.7
9 Episode: Finished after 10 steps: average steps =  9.7
10 Episode: Finished after 8 steps: average steps =  9.3
11 Episode: Finished after 10 steps: average steps =  9.4
12 Episode: Finished after 9 steps: average steps =  9.4
13 Episode: Finished after 11 steps: average steps =  9.4
14 Episode: Finished after 10 steps: average steps =  9.4
15 Episode: Finished after 9 steps: average steps =  9.4
16 Episode: Finished after 11 steps: average steps =  9.5
17 Episode: Finished after 10 steps: average steps =  9.7
18 Episode: Finished after 13 steps: average steps =  10.1
19 Episode: Finished after 13 steps: average steps =  10.4
20 Episode: Finished after 9 steps: average steps =  10.5
21 Episode: Finished after 9 steps: average steps =  10.4
22 Episode: Finished after 10 steps: average steps =  10.5
23 Episode: Finished after 11 steps: average steps =  10.5
24 Episode: Finished after 11 steps: average steps =  10.6
25 Episode: Finished after 10 steps: average steps =  10.7
26 Episode: Finished after 48 steps: average steps =  14.4
27 Episode: Finished after 38 steps: average steps =  17.2
28 Episode: Finished after 32 steps: average steps =  19.1
29 Episode: Finished after 13 steps: average steps =  19.1
30 Episode: Finished after 12 steps: average steps =  19.4
31 Episode: Finished after 12 steps: average steps =  19.7
32 Episode: Finished after 15 steps: average steps =  20.2
33 Episode: Finished after 14 steps: average steps =  20.5
34 Episode: Finished after 13 steps: average steps =  20.7
35 Episode: Finished after 13 steps: average steps =  21.0
36 Episode: Finished after 11 steps: average steps =  17.3
37 Episode: Finished after 10 steps: average steps =  14.5
38 Episode: Finished after 11 steps: average steps =  12.4
39 Episode: Finished after 9 steps: average steps =  12.0
40 Episode: Finished after 11 steps: average steps =  11.9
41 Episode: Finished after 9 steps: average steps =  11.6
42 Episode: Finished after 8 steps: average steps =  10.9
43 Episode: Finished after 11 steps: average steps =  10.6
44 Episode: Finished after 9 steps: average steps =  10.2
45 Episode: Finished after 9 steps: average steps =  9.8
46 Episode: Finished after 10 steps: average steps =  9.7
47 Episode: Finished after 8 steps: average steps =  9.5
48 Episode: Finished after 10 steps: average steps =  9.4
49 Episode: Finished after 11 steps: average steps =  9.6
50 Episode: Finished after 10 steps: average steps =  9.5
51 Episode: Finished after 11 steps: average steps =  9.7
52 Episode: Finished after 10 steps: average steps =  9.9
53 Episode: Finished after 9 steps: average steps =  9.7
54 Episode: Finished after 9 steps: average steps =  9.7
55 Episode: Finished after 10 steps: average steps =  9.8
56 Episode: Finished after 9 steps: average steps =  9.7
57 Episode: Finished after 8 steps: average steps =  9.7
58 Episode: Finished after 8 steps: average steps =  9.5
59 Episode: Finished after 9 steps: average steps =  9.3
60 Episode: Finished after 11 steps: average steps =  9.4
61 Episode: Finished after 8 steps: average steps =  9.1
62 Episode: Finished after 10 steps: average steps =  9.1
63 Episode: Finished after 8 steps: average steps =  9.0
64 Episode: Finished after 10 steps: average steps =  9.1
65 Episode: Finished after 10 steps: average steps =  9.1
66 Episode: Finished after 14 steps: average steps =  9.6
67 Episode: Finished after 10 steps: average steps =  9.8
68 Episode: Finished after 13 steps: average steps =  10.3
69 Episode: Finished after 11 steps: average steps =  10.5
70 Episode: Finished after 12 steps: average steps =  10.6
71 Episode: Finished after 15 steps: average steps =  11.3
72 Episode: Finished after 25 steps: average steps =  12.8
73 Episode: Finished after 57 steps: average steps =  17.7
74 Episode: Finished after 51 steps: average steps =  21.8
75 Episode: Finished after 53 steps: average steps =  26.1
76 Episode: Finished after 69 steps: average steps =  31.6
77 Episode: Finished after 111 steps: average steps =  41.7
78 Episode: Finished after 88 steps: average steps =  49.2
79 Episode: Finished after 48 steps: average steps =  52.9
80 Episode: Finished after 127 steps: average steps =  64.4
81 Episode: Finished after 92 steps: average steps =  72.1
82 Episode: Finished after 53 steps: average steps =  74.9
83 Episode: Finished after 58 steps: average steps =  75.0
84 Episode: Finished after 43 steps: average steps =  74.2
85 Episode: Finished after 38 steps: average steps =  72.7
86 Episode: Finished after 81 steps: average steps =  73.9
87 Episode: Finished after 54 steps: average steps =  68.2
88 Episode: Finished after 38 steps: average steps =  63.2
89 Episode: Finished after 65 steps: average steps =  64.9
90 Episode: Finished after 80 steps: average steps =  60.2
91 Episode: Finished after 47 steps: average steps =  55.7
92 Episode: Finished after 65 steps: average steps =  56.9
93 Episode: Finished after 76 steps: average steps =  58.7
94 Episode: Finished after 92 steps: average steps =  63.6
95 Episode: Finished after 36 steps: average steps =  63.4
96 Episode: Finished after 32 steps: average steps =  58.5
97 Episode: Finished after 37 steps: average steps =  56.8
98 Episode: Finished after 39 steps: average steps =  56.9
99 Episode: Finished after 30 steps: average steps =  53.4
100 Episode: Finished after 117 steps: average steps =  57.1
101 Episode: Finished after 39 steps: average steps =  56.3
102 Episode: Finished after 51 steps: average steps =  54.9
103 Episode: Finished after 31 steps: average steps =  50.4
104 Episode: Finished after 65 steps: average steps =  47.7
105 Episode: Finished after 47 steps: average steps =  48.8
106 Episode: Finished after 48 steps: average steps =  50.4
107 Episode: Finished after 43 steps: average steps =  51.0
108 Episode: Finished after 41 steps: average steps =  51.2
109 Episode: Finished after 58 steps: average steps =  54.0
110 Episode: Finished after 43 steps: average steps =  46.6
111 Episode: Finished after 51 steps: average steps =  47.8
112 Episode: Finished after 61 steps: average steps =  48.8
113 Episode: Finished after 53 steps: average steps =  51.0
114 Episode: Finished after 57 steps: average steps =  50.2
115 Episode: Finished after 66 steps: average steps =  52.1
116 Episode: Finished after 55 steps: average steps =  52.8
117 Episode: Finished after 37 steps: average steps =  52.2
118 Episode: Finished after 64 steps: average steps =  54.5
119 Episode: Finished after 45 steps: average steps =  53.2
120 Episode: Finished after 52 steps: average steps =  54.1
121 Episode: Finished after 81 steps: average steps =  57.1
122 Episode: Finished after 63 steps: average steps =  57.3
123 Episode: Finished after 79 steps: average steps =  59.9
124 Episode: Finished after 91 steps: average steps =  63.3
125 Episode: Finished after 40 steps: average steps =  60.7
126 Episode: Finished after 71 steps: average steps =  62.3
127 Episode: Finished after 39 steps: average steps =  62.5
128 Episode: Finished after 71 steps: average steps =  63.2
129 Episode: Finished after 65 steps: average steps =  65.2
130 Episode: Finished after 150 steps: average steps =  75.0
131 Episode: Finished after 41 steps: average steps =  71.0
132 Episode: Finished after 49 steps: average steps =  69.6
133 Episode: Finished after 55 steps: average steps =  67.2
134 Episode: Finished after 53 steps: average steps =  63.4
135 Episode: Finished after 41 steps: average steps =  63.5
136 Episode: Finished after 108 steps: average steps =  67.2
137 Episode: Finished after 66 steps: average steps =  69.9
138 Episode: Finished after 104 steps: average steps =  73.2
139 Episode: Finished after 200 steps: average steps =  86.7
140 Episode: Finished after 43 steps: average steps =  76.0
141 Episode: Finished after 106 steps: average steps =  82.5
142 Episode: Finished after 46 steps: average steps =  82.2
143 Episode: Finished after 130 steps: average steps =  89.7
144 Episode: Finished after 143 steps: average steps =  98.7
145 Episode: Finished after 83 steps: average steps =  102.9
146 Episode: Finished after 131 steps: average steps =  105.2
147 Episode: Finished after 142 steps: average steps =  112.8
148 Episode: Finished after 132 steps: average steps =  115.6
149 Episode: Finished after 131 steps: average steps =  108.7
150 Episode: Finished after 99 steps: average steps =  114.3
151 Episode: Finished after 154 steps: average steps =  119.1
152 Episode: Finished after 91 steps: average steps =  123.6
153 Episode: Finished after 87 steps: average steps =  119.3
154 Episode: Finished after 88 steps: average steps =  113.8
155 Episode: Finished after 87 steps: average steps =  114.2
156 Episode: Finished after 92 steps: average steps =  110.3
157 Episode: Finished after 91 steps: average steps =  105.2
158 Episode: Finished after 124 steps: average steps =  104.4
159 Episode: Finished after 115 steps: average steps =  102.8
160 Episode: Finished after 94 steps: average steps =  102.3
161 Episode: Finished after 128 steps: average steps =  99.7
162 Episode: Finished after 96 steps: average steps =  100.2
163 Episode: Finished after 91 steps: average steps =  100.6
164 Episode: Finished after 118 steps: average steps =  103.6
165 Episode: Finished after 129 steps: average steps =  107.8
166 Episode: Finished after 149 steps: average steps =  113.5
167 Episode: Finished after 149 steps: average steps =  119.3
168 Episode: Finished after 200 steps: average steps =  126.9
169 Episode: Finished after 200 steps: average steps =  135.4
170 Episode: Finished after 200 steps: average steps =  146.0
171 Episode: Finished after 200 steps: average steps =  153.2
172 Episode: Finished after 200 steps: average steps =  163.6
173 Episode: Finished after 200 steps: average steps =  174.5
174 Episode: Finished after 200 steps: average steps =  182.7
175 Episode: Finished after 200 steps: average steps =  189.8
176 Episode: Finished after 200 steps: average steps =  194.9
177 Episode: Finished after 200 steps: average steps =  200.0
10 consecutive success
178 Episode: Finished after 200 steps: average steps =  200.0
In [15]:
# アニメーション表示する
%matplotlib notebook
import matplotlib.pyplot as plt

#display_frames_as_anim(frames)   # display now
display_frames_as_anim(frames, 'cartpole_video2b.mp4')   # save to file

if is_colab: # copy to google drive
    ! mkdir -p {SAVE_PREFIX}
    ! cp cartpole_video2b.mp4 {SAVE_PREFIX}     # copy to the Google Drive

実行の様子を動画で表示する(HTML)

In [ ]: