Jul/30/2023 updated by Yoshihisa Nitta
Feb/03/2023 written by Yoshihisa Nitta
Google Colab 対応コード

In [1]:
# by nitta
import os
is_colab = 'google.colab' in str(get_ipython())   # for Google Colab

if is_colab:
    from google.colab import drive
    drive.mount('/content/drive')
    SAVE_PREFIX='/content/drive/MyDrive/DeepLearning/openai_gym_rt'
else:
    SAVE_PREFIX='.'
Mounted at /content/drive

Reinforced Training (強化学習) と Deep Reinforced Training (深層強化学習)

1. Google Colab

In [2]:
# Check if running on Colab
#is_colab = 'google.colab' in str(get_ipython())   # for Google Colab

# packages
if is_colab:
    !apt update -qq
    !apt upgrade -qq
    !apt install -qq xvfb
    !pip -q install pyvirtualdisplay

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import matplotlib
matplotlib.rcParams['animation.embed_limit'] = 2**32
39 packages can be upgraded. Run 'apt list --upgradable' to see them.
The following packages have been kept back:
  libcudnn8 libcudnn8-dev libnccl-dev libnccl2
The following packages will be upgraded:
  base-files bash binutils binutils-common binutils-x86-64-linux-gnu coreutils cuda-compat-12-2
  cuda-keyring cuda-toolkit-12-config-common cuda-toolkit-config-common dpkg dpkg-dev libbinutils
  libc-bin libctf-nobfd0 libctf0 libdpkg-perl libgnutls30 libldap-2.5-0 libpam-modules
  libpam-modules-bin libpam-runtime libpam0g libperl5.34 libprocps8 libudev1 linux-libc-dev login
  openssl passwd perl perl-base perl-modules-5.34 procps tar
35 upgraded, 0 newly installed, 0 to remove and 4 not upgraded.
Need to get 58.1 MB of archives.
After this operation, 63.5 kB of additional disk space will be used.
Extracting templates from packages: 100%
Preconfiguring packages ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../base-files_12ubuntu4.6_amd64.deb ...
Unpacking base-files (12ubuntu4.6) over (12ubuntu4.4) ...
Setting up base-files (12ubuntu4.6) ...
Installing new version of config file /etc/issue ...
Installing new version of config file /etc/issue.net ...
Installing new version of config file /etc/lsb-release ...
Installing new version of config file /etc/update-motd.d/10-help-text ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../bash_5.1-6ubuntu1.1_amd64.deb ...
Unpacking bash (5.1-6ubuntu1.1) over (5.1-6ubuntu1) ...
Setting up bash (5.1-6ubuntu1.1) ...
update-alternatives: using /usr/share/man/man7/bash-builtins.7.gz to provide /usr/share/man/man7/builtins.7.gz (builtins.7.gz) in auto mode
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../coreutils_8.32-4.1ubuntu1.1_amd64.deb ...
Unpacking coreutils (8.32-4.1ubuntu1.1) over (8.32-4.1ubuntu1) ...
Setting up coreutils (8.32-4.1ubuntu1.1) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../tar_1.34+dfsg-1ubuntu0.1.22.04.2_amd64.deb ...
Unpacking tar (1.34+dfsg-1ubuntu0.1.22.04.2) over (1.34+dfsg-1ubuntu0.1.22.04.1) ...
Setting up tar (1.34+dfsg-1ubuntu0.1.22.04.2) ...
update-alternatives: warning: forcing reinstallation of alternative /usr/sbin/rmt-tar because link group rmt is broken
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../dpkg_1.21.1ubuntu2.3_amd64.deb ...
Unpacking dpkg (1.21.1ubuntu2.3) over (1.21.1ubuntu2.2) ...
Setting up dpkg (1.21.1ubuntu2.3) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../login_1%3a4.8.1-2ubuntu2.2_amd64.deb ...
Unpacking login (1:4.8.1-2ubuntu2.2) over (1:4.8.1-2ubuntu2.1) ...
Setting up login (1:4.8.1-2ubuntu2.2) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libperl5.34_5.34.0-3ubuntu1.3_amd64.deb ...
Unpacking libperl5.34:amd64 (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Preparing to unpack .../perl_5.34.0-3ubuntu1.3_amd64.deb ...
Unpacking perl (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Preparing to unpack .../perl-base_5.34.0-3ubuntu1.3_amd64.deb ...
Unpacking perl-base (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Setting up perl-base (5.34.0-3ubuntu1.3) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../perl-modules-5.34_5.34.0-3ubuntu1.3_all.deb ...
Unpacking perl-modules-5.34 (5.34.0-3ubuntu1.3) over (5.34.0-3ubuntu1.2) ...
Preparing to unpack .../libc-bin_2.35-0ubuntu3.6_amd64.deb ...
Unpacking libc-bin (2.35-0ubuntu3.6) over (2.35-0ubuntu3.4) ...
Setting up libc-bin (2.35-0ubuntu3.6) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam0g_1.4.0-11ubuntu2.4_amd64.deb ...
Unpacking libpam0g:amd64 (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam0g:amd64 (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam-modules-bin_1.4.0-11ubuntu2.4_amd64.deb ...
Unpacking libpam-modules-bin (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam-modules-bin (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam-modules_1.4.0-11ubuntu2.4_amd64.deb ...
Unpacking libpam-modules:amd64 (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam-modules:amd64 (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libpam-runtime_1.4.0-11ubuntu2.4_all.deb ...
Unpacking libpam-runtime (1.4.0-11ubuntu2.4) over (1.4.0-11ubuntu2.3) ...
Setting up libpam-runtime (1.4.0-11ubuntu2.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libudev1_249.11-0ubuntu3.12_amd64.deb ...
Unpacking libudev1:amd64 (249.11-0ubuntu3.12) over (249.11-0ubuntu3.10) ...
Setting up libudev1:amd64 (249.11-0ubuntu3.12) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../passwd_1%3a4.8.1-2ubuntu2.2_amd64.deb ...
Unpacking passwd (1:4.8.1-2ubuntu2.2) over (1:4.8.1-2ubuntu2.1) ...
Setting up passwd (1:4.8.1-2ubuntu2.2) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../libgnutls30_3.7.3-4ubuntu1.4_amd64.deb ...
Unpacking libgnutls30:amd64 (3.7.3-4ubuntu1.4) over (3.7.3-4ubuntu1.2) ...
Setting up libgnutls30:amd64 (3.7.3-4ubuntu1.4) ...
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../00-libprocps8_2%3a3.3.17-6ubuntu2.1_amd64.deb ...
Unpacking libprocps8:amd64 (2:3.3.17-6ubuntu2.1) over (2:3.3.17-6ubuntu2) ...
Preparing to unpack .../01-procps_2%3a3.3.17-6ubuntu2.1_amd64.deb ...
Unpacking procps (2:3.3.17-6ubuntu2.1) over (2:3.3.17-6ubuntu2) ...
Preparing to unpack .../02-openssl_3.0.2-0ubuntu1.15_amd64.deb ...
Unpacking openssl (3.0.2-0ubuntu1.15) over (3.0.2-0ubuntu1.12) ...
Preparing to unpack .../03-libctf0_2.38-4ubuntu2.6_amd64.deb ...
Unpacking libctf0:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../04-libctf-nobfd0_2.38-4ubuntu2.6_amd64.deb ...
Unpacking libctf-nobfd0:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../05-binutils-x86-64-linux-gnu_2.38-4ubuntu2.6_amd64.deb ...
Unpacking binutils-x86-64-linux-gnu (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../06-libbinutils_2.38-4ubuntu2.6_amd64.deb ...
Unpacking libbinutils:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../07-binutils_2.38-4ubuntu2.6_amd64.deb ...
Unpacking binutils (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../08-binutils-common_2.38-4ubuntu2.6_amd64.deb ...
Unpacking binutils-common:amd64 (2.38-4ubuntu2.6) over (2.38-4ubuntu2.3) ...
Preparing to unpack .../09-cuda-compat-12-2_535.161.08-1_amd64.deb ...
Unpacking cuda-compat-12-2 (535.161.08-1) over (535.129.03-1) ...
Preparing to unpack .../10-cuda-keyring_1.1-1_all.deb ...
Unpacking cuda-keyring (1.1-1) over (1.0-1) ...
Preparing to unpack .../11-cuda-toolkit-12-config-common_12.4.99-1_all.deb ...
Unpacking cuda-toolkit-12-config-common (12.4.99-1) over (12.3.52-1) ...
Preparing to unpack .../12-cuda-toolkit-config-common_12.4.99-1_all.deb ...
Unpacking cuda-toolkit-config-common (12.4.99-1) over (12.3.52-1) ...
Preparing to unpack .../13-dpkg-dev_1.21.1ubuntu2.3_all.deb ...
Unpacking dpkg-dev (1.21.1ubuntu2.3) over (1.21.1ubuntu2.2) ...
Preparing to unpack .../14-libdpkg-perl_1.21.1ubuntu2.3_all.deb ...
Unpacking libdpkg-perl (1.21.1ubuntu2.3) over (1.21.1ubuntu2.2) ...
Preparing to unpack .../15-libldap-2.5-0_2.5.17+dfsg-0ubuntu0.22.04.1_amd64.deb ...
Unpacking libldap-2.5-0:amd64 (2.5.17+dfsg-0ubuntu0.22.04.1) over (2.5.16+dfsg-0ubuntu0.22.04.1) ...
Preparing to unpack .../16-linux-libc-dev_5.15.0-101.111_amd64.deb ...
Unpacking linux-libc-dev:amd64 (5.15.0-101.111) over (5.15.0-88.98) ...
Setting up cuda-toolkit-config-common (12.4.99-1) ...
Setting up cuda-compat-12-2 (535.161.08-1) ...
Setting up binutils-common:amd64 (2.38-4ubuntu2.6) ...
Setting up linux-libc-dev:amd64 (5.15.0-101.111) ...
Setting up libctf-nobfd0:amd64 (2.38-4ubuntu2.6) ...
Setting up perl-modules-5.34 (5.34.0-3ubuntu1.3) ...
Setting up libldap-2.5-0:amd64 (2.5.17+dfsg-0ubuntu0.22.04.1) ...
Setting up cuda-keyring (1.1-1) ...
Setting up libbinutils:amd64 (2.38-4ubuntu2.6) ...
Setting up openssl (3.0.2-0ubuntu1.15) ...
Setting up cuda-toolkit-12-config-common (12.4.99-1) ...
Setting up libprocps8:amd64 (2:3.3.17-6ubuntu2.1) ...
Setting up libctf0:amd64 (2.38-4ubuntu2.6) ...
Setting up libperl5.34:amd64 (5.34.0-3ubuntu1.3) ...
Setting up perl (5.34.0-3ubuntu1.3) ...
Setting up libdpkg-perl (1.21.1ubuntu2.3) ...
Setting up procps (2:3.3.17-6ubuntu2.1) ...
Setting up binutils-x86-64-linux-gnu (2.38-4ubuntu2.6) ...
Setting up binutils (2.38-4ubuntu2.6) ...
Setting up dpkg-dev (1.21.1ubuntu2.3) ...
Processing triggers for man-db (2.10.2-1) ...
Processing triggers for libc-bin (2.35-0ubuntu3.6) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

The following additional packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common
The following NEW packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common xvfb
0 upgraded, 9 newly installed, 0 to remove and 4 not upgraded.
Need to get 7,814 kB of archives.
After this operation, 11.9 MB of additional disk space will be used.
Selecting previously unselected package libfontenc1:amd64.
(Reading database ... 121753 files and directories currently installed.)
Preparing to unpack .../0-libfontenc1_1%3a1.1.4-1build3_amd64.deb ...
Unpacking libfontenc1:amd64 (1:1.1.4-1build3) ...
Selecting previously unselected package libxfont2:amd64.
Preparing to unpack .../1-libxfont2_1%3a2.0.5-1build1_amd64.deb ...
Unpacking libxfont2:amd64 (1:2.0.5-1build1) ...
Selecting previously unselected package libxkbfile1:amd64.
Preparing to unpack .../2-libxkbfile1_1%3a1.1.0-1build3_amd64.deb ...
Unpacking libxkbfile1:amd64 (1:1.1.0-1build3) ...
Selecting previously unselected package x11-xkb-utils.
Preparing to unpack .../3-x11-xkb-utils_7.7+5build4_amd64.deb ...
Unpacking x11-xkb-utils (7.7+5build4) ...
Selecting previously unselected package xfonts-encodings.
Preparing to unpack .../4-xfonts-encodings_1%3a1.0.5-0ubuntu2_all.deb ...
Unpacking xfonts-encodings (1:1.0.5-0ubuntu2) ...
Selecting previously unselected package xfonts-utils.
Preparing to unpack .../5-xfonts-utils_1%3a7.7+6build2_amd64.deb ...
Unpacking xfonts-utils (1:7.7+6build2) ...
Selecting previously unselected package xfonts-base.
Preparing to unpack .../6-xfonts-base_1%3a1.0.5_all.deb ...
Unpacking xfonts-base (1:1.0.5) ...
Selecting previously unselected package xserver-common.
Preparing to unpack .../7-xserver-common_2%3a21.1.4-2ubuntu1.7~22.04.8_all.deb ...
Unpacking xserver-common (2:21.1.4-2ubuntu1.7~22.04.8) ...
Selecting previously unselected package xvfb.
Preparing to unpack .../8-xvfb_2%3a21.1.4-2ubuntu1.7~22.04.8_amd64.deb ...
Unpacking xvfb (2:21.1.4-2ubuntu1.7~22.04.8) ...
Setting up libfontenc1:amd64 (1:1.1.4-1build3) ...
Setting up xfonts-encodings (1:1.0.5-0ubuntu2) ...
Setting up libxkbfile1:amd64 (1:1.1.0-1build3) ...
Setting up libxfont2:amd64 (1:2.0.5-1build1) ...
Setting up x11-xkb-utils (7.7+5build4) ...
Setting up xfonts-utils (1:7.7+6build2) ...
Setting up xfonts-base (1:1.0.5) ...
Setting up xserver-common (2:21.1.4-2ubuntu1.7~22.04.8) ...
Setting up xvfb (2:21.1.4-2ubuntu1.7~22.04.8) ...
Processing triggers for man-db (2.10.2-1) ...
Processing triggers for fontconfig (2.13.1-4.2ubuntu5) ...
Processing triggers for libc-bin (2.35-0ubuntu3.6) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_0.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbb.so.12 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbbind_2_5.so.3 is not a symbolic link

In [3]:
# Show multiple images as animation
# Colab compatible
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from IPython import display
import os

import gym
def display_frames_as_anim(frames, filepath=None):
    """
    Displays a list of frames as a gif, with controls
    """
    H, W, _ = frames[0].shape
    fig, ax = plt.subplots(1, 1, figsize=(W/100.0, H/100.0))
    ax.axis('off')
    patch = plt.imshow(frames[0])

    def animate(i):
        display.clear_output(wait=True)
        patch.set_data(frames[i])
        return patch

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50, repeat=False)

    if not filepath is None:
        dpath, fname = os.path.split(filepath)
        if dpath != '' and not os.path.exists(dpath):
            os.makedirs(dpath)
        anim.save(filepath)

    if is_colab:
        display.display(display.HTML(anim.to_jshtml()))
        #plt.close()
    else:
        plt.show()

2. OpenAI Gym : CartPole

In [4]:
import gym

ENV = 'CartPole-v1'
/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)
In [5]:
# ランダムに行動する
np.random.seed(12345)

env = gym.make(ENV, new_step_api=True, render_mode='rgb_array') # new_step_api, render_mode
observation = env.reset()

frames = [ env.render()[0] ]
states = [ observation ]
for step in range(100):
    action = env.action_space.sample()  ## np.random.choice(2) # 0: left, 1: right  ##
    obserbation, reward, done, truncated, info = env.step(action) # when new_step_api==True, two bool (done, truncated) returned
    frames.append(env.render()[0])  # env.render(mode='rgb_array') -> env.render()[0]
    states.append(observation)

    if done:
        break

print(len(states))
print(states)
14
[array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32), array([-0.04642091, -0.03470421, -0.00835946,  0.0287855 ], dtype=float32)]
/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:241: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
  if not isinstance(terminated, (bool, np.bool8)):
In [6]:
# アニメーション表示する
%matplotlib notebook
import matplotlib.pyplot as plt

#display_frames_as_anim(frames)   # display now
display_frames_as_anim(frames, 'cartpole_video5a.mp4')   # save to file

if is_colab: # copy to google drive
    ! mkdir -p {SAVE_PREFIX}
    ! cp cartpole_video5a.mp4 {SAVE_PREFIX}     # copy to the Google Drive

実行の様子を動画で表示する(HTML)

3. 深層強化学習

ディープラーニングを使わない従来の強化学習として3通りの手法があった。

  • Q学習
  • Sarsa
  • 方策反復法

このうち、Q学習における行動価値関数 $Q(s,a)$ をディープ・ニューラルネットワークで実現する DQN が提案された。

In [7]:
# CartPole で観測した状態変数の値を名前をつけて保存するために namedtuple を使う
from collections import namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
In [8]:
# 定数

GAMMA = 0.99
MAX_STEPS = 200
NUM_EPISODES = 500
In [9]:
# Replay Memory
import random

class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.index = 0

    def push(self, state, action, state_next, reward):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.index] = Transition(state, action, state_next, reward)
        self.index = (self.index + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

A2C

A3C (Asynchronous Advantage Actor-Critic) は3つの工夫を組み合わせたアルゴリズムである。 Asynchronous は、「複数のエージェントを用意し非同期)で分散学習する」ことを意味する。 Advance は、「(1 step だけでなく、) 2 step 以上先までの状態を使用して更新を行う」ことを意味する。 Actor-Critic は、「方策反復法と価値反復法を組み合わせる」ことを意味し、Actor は方策を出力する関数、Critic は価値を出力する関数を表す。

A2C は、A3C から非同期的要素を取り除いたものである。複数のエージェントは1つのニューラルネットワークを共有する。

6.5 A2C の実装

A2C の概要

A2C は A3C という手法から派生した分散学習型の深層強化学習であり、 A2C という名前は、Advantage 学習と Actor-Critic に由来する。 この2つの手法を分散型の強化学習に組み合わせて利用する。

分散学習は、エージェントを複数用意して強化学習を行が、 全エージェントが同じディープ・ニューラルネットワークを共有する。

Advantage 学習

Q学習やDQNではQ関数の更新の際、 $Q(s_t,a_t)$ が $R(t+1) + \gamma \cdot \max [ Q(s_{t+1}, a) ]$ に近づくように、Q関数を学習する。 $Q(s_t, a_t)$ の学習に 1 step 先の行動価値関数の値 $Q(s_{t+1}, a)$ を使う。 Advantage 学習はこのQ関数の更新を 1 step 先でなく、2 step 以上先まで動かして更新する。

2 step 先まで考慮した場合のQ関数の更新式は次の通り。

$Q(s_t, a_t) \rightarrow R(t+1) + \gamma \cdot R(t+2) + (\gamma^2) \cdot \max_a [Q(s_{t+2}, a)]$

ただし、何ステップも先まで Advantage 学修すると、何ステップも最適ではない Q関数で行動を決定して間違った学習をする確率が増えてしまう。 そのため、適度なステップ数で Advantage 学習をするのが一般的である。

Actor-Critic

Q学習は価値反復法の手法であるが、Actor-Critic は方策反復法と価値反復法の両方を使用する。

Actor-Critic のニューラルネットワークでは、入力は DQN と同じく状態変数である。、たとえばCartPole では、「位置」「速度」「角度」「角速度」の4変数が入力となる。

Actor-Critic の出力は Actor と Critic それぞれの出力の集合である。 Actor は行動を出力するので、出力の個数は「行動の種類数」である。 また、Critic は状態価値 $V^{\pi}_{s_t}$を出力するので、出力の個数は1である。 たとえば CartPole では行動は2種類なので、全体の出力数は3個となる。

状態価値 $V^{\pi}_{s_t}$ は、状態 $s_t$ になった場合にその先得られるであろう割引報酬和の期待値である。

Actor-Critic のニューラルネットワークの結合パラメータの学習法

誤差関数を定義する。

Actor 側で最大化したいのは、状態 $s_t$ において、結合パラメータ $\theta$ のニューラルネットワークを使用して行動し続けたときに得られる割引報酬和 $J(\theta , s_t )$ である。 方策勾配法を使うと割引報酬和は次の通り。

$J(\theta , s_t) = \mathbb{E} [ \log \pi_{\theta} (a|s) (Q^{\pi}(s,a) - V^{\pi}_{s})]$

$\mathbb{E}[]$ は期待値を計算するという意味で、実装時にはミニバッチの平均を求める。 $\log \pi_{\theta} (a|s)$ は状態 $s$ のときに行動 $a$ を採用する確率の $\log$ を計算したものである。

$Q^{\pi}(s,a)$ は状態 $s$ で行動 $a$ を採用した場合の行動価値である。 ただし、$Q^{\pi}(s,a)$ は行動 $a$ についての変数ではなく、定数として扱う。 A2C では行動価値を Advantage 学習で計算する。 $V^{\pi}_s$ は状態価値であり、Critic の出力である。

A2C および A3C では、Actor の学習に方策のエントロピー項を追加する。 エントロピー項は次の通り。

$\displaystyle \mbox{Actor}_\mbox{entropy} = \sum^{a} [\pi_{\theta} (a|s) \log \pi_{\theta} (a|s)]$

総和は行動の種類について総和を計算する意味である。 このエントロピー項は方策が行動をランダムに選択する作戦の場合(学習初期)が最大の値となる。 どれか一つの行動しか選択しない方策の場合はエントロピーが最小になる。 エントロピー項を追加することによって、学習初期は学習がゆっくりとなり、局所解に落ちるのを避けている。

Critic 側は状態価値 $V^{\pi}_s$ を正しく出力するように学習したいので、実際に行動して得られた行動価値 $Q^{\pi}(s,a)$ と出力 $V^{\pi}_s$ が一致するように学習する。 損失関数は次の通り。

$\mbox{loss}_\mbox{{Critic}} = (Q^{\pi}(s,a) - V^{\pi}_s)^2$

A2C の実装

In [10]:
# p.177

NUM_PROCESSES = 32     # 同時に実行する環境
NUM_ADVANCED_STEP = 5  # 何ステップ進めて報酬和を計算するのか設定

# A2C の損失関数の計算のための定数設定
value_loss_coef = 0.5
entropy_coef = 0.01
max_grad_norm = 0.5
In [11]:
# p.178
# メモリクラスの定義

class RolloutStorage(object):
    '''
    Advantage 学習をするためのメモリクラス
    '''

    def __init__(self, num_steps, num_processes, obj_shape):
        self.observations = torch.zeros(num_steps + 1, num_processes, 4)
        self.masks = torch.ones(num_steps + 1, num_processes, 1)
        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.actions = torch.zeros(num_steps, num_processes, 1).long()

        # 割引報酬和を格納
        self.returns = torch.zeros(num_steps + 1, num_processes, 1)
        self.index = 0  # insert する index

    def insert(self, current_obs, action, reward, mask):
        '''
        次のindexにtransactionを格納する
        '''
        self.observations[self.index+1].copy_(current_obs)
        self.masks[self.index+1].copy_(mask)
        self.rewards[self.index].copy_(reward)
        self.actions[self.index].copy_(action)

        self.index = (self.index + 1) % NUM_ADVANCED_STEP   # indexの更新

    def after_update(self):
        '''
        Advantage する step 数が完了したら、最新のものを index0 に格納する
        '''
        self.observations[0].copy_(self.observations[-1])
        self.masks[0].copy_(self.masks[-1])

    def compute_returns(self, next_value):
        '''
        Advantage するステップ中の各ステップの割引報酬和を計算する。
        5 step 目から逆向きに計算する。5step -> Advantage1, 4step -> Advantage2
        '''
        self.returns[-1] = next_value
        for ad_step in reversed(range(self.rewards.size(0))):
            self.returns[ad_step] = self.returns[ad_step + 1] * GAMMA * self.masks[ad_step + 1] + self.rewards[ad_step]
In [12]:
# p.179
# A2C の Deep Neural Network を構築する

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(torch.nn.Module):

    def __init__(self, n_in, n_mid, n_out):
        super().__init__()
        self.fc1 = torch.nn.Linear(n_in, n_mid)
        self.fc2 = torch.nn.Linear(n_mid, n_mid)
        self.actor = torch.nn.Linear(n_mid, n_out)  # Actor
        self.critic = torch.nn.Linear(n_mid, 1)    # Critic

    def forward(self, x):
        h1 = torch.nn.functional.relu(self.fc1(x))
        h2 = torch.nn.functional.relu(self.fc2(h1))
        critic_output = self.critic(h2)  # 状態価値
        actor_output = self.actor(h2)    # 行動

        return critic_output, actor_output

    def act(self, x):
        '''
        状態xから行動を確率的に求める
        '''
        value, actor_output = self(x)   # forward()関数をcallする

        action_probs = torch.nn.functional.softmax(actor_output, dim=1)
        action = action_probs.multinomial(num_samples=1)

        return action

    def get_value(self, x):
        '''
        状態xから状態価値を計算する
        '''
        value, actor_output = self(x)

        return value

    def evaluate_actions(self, x, actions):
        '''
        状態xから状態価値、実際の行動actionsのlog確率とエントロピーを求める
        '''
        value, actor_output = self(x)

        log_probs = torch.nn.functional.log_softmax(actor_output, dim=1)
        action_log_probs = log_probs.gather(1, actions)  # 実際の行動のlog_probsを求める

        probs = torch.nn.functional.softmax(actor_output, dim=1)
        entropy = -(log_probs * probs).sum(-1).mean()

        return value, action_log_probs, entropy
In [13]:
# p.180
# エージェントの頭脳クラスを定義する。全エージェントで共有する

import torch
from torch import optim

class Brain(object):

    def __init__(self, actor_critic):
        self.actor_critic = actor_critic
        self.optimizer = torch.optim.Adam(self.actor_critic.parameters(), lr=0.01)

    def update(self, rollouts):
        '''
        Advantage で計算した5つのstepのすべてを使って更新する
        '''
        obs_shape = rollouts.observations.size()[2:] # torch.Size([4, 84, 84])
        num_steps = NUM_ADVANCED_STEP
        num_processes = NUM_PROCESSES

        values, action_log_probs, entropy = self.actor_critic.evaluate_actions(
            rollouts.observations[:-1].view(-1, 4),
            rollouts.actions.view(-1, 1)
        )

        # rollouts.observations[:-1].view(-1, 4)  torch.Size([80, 4])
        # rollouts.actions.view(-1, 1)            torch.Size([80, 1])
        # values                                  torch.Size([80, 1])
        # action_log_probs                        torch.Size([80, 1])
        # entropy                                 torch.Size([])

        values = values.view(num_steps, num_processes, 1) # torch.Size([5, 32, 1])
        action_log_probs = action_log_probs.view(num_steps, num_processes, 1)

        # advantage (行動価値-状態価値)の計算
        advantages = rollouts.returns[:-1] - values  # torch.Size([5, 32, 1])

        # Critic の loss を計算
        value_loss = advantages.pow(2).mean()

        # Actor の gain を計算する。あとでマイナスをかけて loss にする
        action_gain = (action_log_probs * advantages.detach()).mean()
        # detach して advantages を定数として扱う

        # 誤差関数の総和
        total_loss = (value_loss * value_loss_coef - action_gain - entropy * entropy_coef)

        # 結合パラメータを更新する
        self.actor_critic.train()   # 訓練モードに
        self.optimizer.zero_grad()  # 勾配をリセット
        total_loss.backward()       # back propagation を計算する
        torch.nn.utils.clip_grad_norm(self.actor_critic.parameters(), max_grad_norm)
        # 一気に変化しないように 0.5 でその値を取り出す

        self.optimizer.step()   # 結合パラメータを更新する
/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)

Agent クラスは今回は用意せず、Environment クラスでマルチエージェントを取り扱うことにする。

Environment クラスではエージェントを複数生成し、Advantage 学習による報酬の計算も行う。

In [14]:
# p.1782
# [自分へのメモ] env.step() の返り値の変更、env.render() の引数と返り値の変更に対応した
# run() の中で動画(frames)を表示せず、frames を返すように変更した

import copy

class Environment:

    def __init__(self, max_step=200):
        self.max_step = max_step      ## by nitta
        # 同時に実行するエージェント数の環境を生成する
        self.envs = [ gym.make(ENV, new_step_api=True, render_mode='rgb_array') for i in range(NUM_PROCESSES) ]
        # 全エージェントが共有して持つ頭脳 Brain を生成する
        self.n_in = self.envs[0].observation_space.shape[0]  # 状態の数は 4
        self.n_out = self.envs[0].action_space.n  # 行動の数は2
        self.n_mid = 32
        self.actor_critic = Net(self.n_in, self.n_mid, self.n_out)
        self.global_brain = Brain(self.actor_critic)

    def run(self):
        # 格納用変数を生成する
        obs_shape = self.n_in
        current_obs = torch.zeros(NUM_PROCESSES, obs_shape)  # torch.Size([32, 4])
        rollouts = RolloutStorage(
            NUM_ADVANCED_STEP,
            NUM_PROCESSES,
            obs_shape
        )
        episode_rewards = torch.zeros([NUM_PROCESSES, 1])  # 現在の試行の報酬
        final_rewards = torch.zeros([NUM_PROCESSES, 1])  # 最後の試行の報酬
        obs_np = np.zeros([NUM_PROCESSES, obs_shape])
        reward_np = np.zeros([NUM_PROCESSES, 1])
        done_np = np.zeros([NUM_PROCESSES, 1])
        truncated_np = np.zeros([NUM_PROCESSES, 1]) ## by nitta
        each_step = np.zeros(NUM_PROCESSES)  # 各環境のstep数を記録
        episode = 0  # 環境0の試行数

        # 初期状態の開始
        obs = [self.envs[i].reset() for i in range(NUM_PROCESSES)]
        obs = np.array(obs)
        obs = torch.from_numpy(obs).float()  # torch.Size([32, 4])
        current_obs = obs  # 最新のobs

        # advanced 学習用のオブジェクト rollouts の状態の1つ目に、現在の状態を保持
        rollouts.observations[0].copy_(current_obs)

        # 実行ループ
        for j in range(NUM_EPISODES * NUM_PROCESSES):
            # advanced 学習する step 数ごとに計算する
            for step in range(NUM_ADVANCED_STEP):

                # 行動を決める
                with torch.no_grad():
                    action = self.actor_critic.act(rollouts.observations[step])

                # (32, 1) -> (32,) -> tensor を Numpy へ
                actions = action.squeeze(1).numpy()

                # 1 step の実行
                for i in range(NUM_PROCESSES):
                    obs_np[i], reward_np[i], done_np[i], truncated_np[i], _ = self.envs[i].step(actions[i]) ## by nitta

                    if each_step[i]+1 >= self.max_step:    ## by nitta
                        truncated_np[i] = True

                    # episode の終了評価と、state_next を設定する
                    if done_np[i] or truncated_np[i]: ## by nitta

                        # 環境0 のときのみ出力する
                        if i == 0:
                            print(f'{episode} Episode: Finished after {each_step[i]+1} steps')
                            episode += 1

                        # 報酬を設定する
                        if each_step[i] < self.max_step - 5:
                            reward_np[i] = -1.0  # 途中でこけたら罰則として報酬(-1)を与える
                        else:
                            reward_np[i] = 1.0  # 立ったまま修了は報酬(1)を与える

                        each_step[i] = 0  # step数のリセット
                        obs_np[i] = self.envs[i].reset()   # 実行環境のリセット

                    else:
                        reward_np[i] = 0.0   # 普段は報酬0
                        each_step[i] += 1

                # 報酬を tensor に変換し、試行の報酬額に足す
                reward = torch.from_numpy(reward_np).float()
                episode_rewards += reward

                # 各実行環境それぞれについて、done ならmaskは0に、継続中ならmask は1にする
                masks_ = []
                for done_, truncated_ in zip(done_np, truncated_np):
                    if done_ or truncated_:
                        masks_.append([0.0])
                    else:
                        masks_.append([1.0])
                masks = torch.FloatTensor(masks_)
                #masks = torch.FloatTensor(
                #    [[0.0] if done_ or truncated_ else [1.0] for done_, truncated_in zip(done_np, truncated_np)]
                #)

                # 最後の試行の総報酬額を更新する
                final_rewards *= masks  # 継続中は1をかけてそのまま、doneでは0をかけてリセット
                # 継続中は0を足す。doneではepisode_rewardsを足す
                final_rewards += (1 - masks) * episode_rewards

                # 試行の総報酬額を更新する
                episode_rewards *= masks  # 継続中は1をかけてそのまま。doneでは0に

                # 現在の状態をdoneで全部0にする
                current_obs *= masks

                # current_obs を更新する
                obs = torch.from_numpy(obs_np).float()  # torch.Size([32, 4])
                current_obs = obs  # 最新のobsを格納する

                # メモリオブジェクトに現stepのtransitionを挿入する
                rollouts.insert(current_obs, action.data, reward, masks)

            # advanced のfor loop終了

            # advanced した最終stepの状態から予想する状態価値を計算する
            with torch.no_grad():
                next_value = self.actor_critic.get_value(
                    rollouts.observations[-1]  # torch.Size([6, 32, 4])
                ).detach()

            # 全stepの割引報酬和を計算して、rollouts の変数 returns を更新する
            rollouts.compute_returns(next_value)

            # ネットワークとrolloutを更新する
            self.global_brain.update(rollouts)
            rollouts.after_update()

            # 全部のNUM_PROCESSESが200step立ち続けたら成功
            if final_rewards.sum().numpy() >= NUM_PROCESSES:
                print('success')
                break
In [15]:
# main クラス
cartpole_env = Environment(max_step=300)
cartpole_env.run()
/usr/local/lib/python3.10/dist-packages/gym/utils/passive_env_checker.py:241: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
  if not isinstance(terminated, (bool, np.bool8)):
<ipython-input-13-bdf8e3f576b0>:52: UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.
  torch.nn.utils.clip_grad_norm(self.actor_critic.parameters(), max_grad_norm)
0 Episode: Finished after 11.0 steps
1 Episode: Finished after 10.0 steps
2 Episode: Finished after 12.0 steps
3 Episode: Finished after 22.0 steps
4 Episode: Finished after 19.0 steps
5 Episode: Finished after 10.0 steps
6 Episode: Finished after 13.0 steps
7 Episode: Finished after 134.0 steps
8 Episode: Finished after 51.0 steps
9 Episode: Finished after 33.0 steps
10 Episode: Finished after 157.0 steps
11 Episode: Finished after 65.0 steps
12 Episode: Finished after 113.0 steps
13 Episode: Finished after 14.0 steps
14 Episode: Finished after 107.0 steps
15 Episode: Finished after 179.0 steps
16 Episode: Finished after 94.0 steps
17 Episode: Finished after 108.0 steps
18 Episode: Finished after 174.0 steps
19 Episode: Finished after 73.0 steps
20 Episode: Finished after 53.0 steps
21 Episode: Finished after 38.0 steps
22 Episode: Finished after 36.0 steps
23 Episode: Finished after 299.0 steps
24 Episode: Finished after 224.0 steps
25 Episode: Finished after 200.0 steps
26 Episode: Finished after 114.0 steps
27 Episode: Finished after 60.0 steps
28 Episode: Finished after 92.0 steps
29 Episode: Finished after 135.0 steps
30 Episode: Finished after 242.0 steps
31 Episode: Finished after 286.0 steps
32 Episode: Finished after 289.0 steps
33 Episode: Finished after 88.0 steps
34 Episode: Finished after 300.0 steps
success
In [16]:
frames = []

observation = cartpole_env.envs[0].reset()

for _ in range(cartpole_env.max_step):
    frames.append(cartpole_env.envs[0].render()[0])
    observation = torch.from_numpy(observation.reshape(-1,4)).float()
    action = cartpole_env.actor_critic.act(observation)
    action = action.squeeze().numpy()
    observation, reward, done, truncated, info = cartpole_env.envs[0].step(action)
    if done or truncated:
        break
/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.
  and should_run_async(code)
In [17]:
# アニメーション表示する
%matplotlib notebook
import matplotlib.pyplot as plt

#display_frames_as_anim(frames)   # display now
display_frames_as_anim(frames, 'cartpole_video5b.mp4')   # save to file

if is_colab: # copy to google drive
    ! mkdir -p {SAVE_PREFIX}
    ! cp cartpole_video5b.mp4 {SAVE_PREFIX}     # copy to the Google Drive

実行の様子を動画で表示する(HTML)

In [ ]: