matplotlib 入門 (14) nitta@tsuda.ac.jp

14章: アニメーション (Google Colab 対応) (2)

OpenAI Gym Envs on Google Colab

OpenAI Gym Envs を Google Colab 上でアニメーション表示するためには、PyVirtualDisplay 経由で仮想ディスプレイ Xvfb を使う必要がある。

In [ ]:
# Detect Google Colab by looking for 'google.colab' in the string form of the
# IPython shell object. get_ipython() is only injected by IPython; when this
# code runs under plain Python the name does not exist, so treat that case as
# "not Colab" instead of failing with NameError.
try:
    is_colab = 'google.colab' in str(get_ipython())   # for Google Colab
except NameError:
    is_colab = False
In [ ]:
# Colab only: refresh the apt package index, install the Xvfb virtual X server,
# and install the PyVirtualDisplay Python package that the later cells import
# to start a virtual display before rendering.
if is_colab:
    !apt update -qq
    !apt install -qq xvfb
    !pip -q install pyvirtualdisplay
63 packages can be upgraded. Run 'apt list --upgradable' to see them.
xvfb is already the newest version (2:1.19.6-1ubuntu4.10).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 63 not upgraded.

1フレームずつ画像を表示してアニメーション

In [ ]:
%matplotlib inline
import gym
from IPython import display
import matplotlib.pyplot as plt

if is_colab:
    from pyvirtualdisplay import Display
    vdisplay = Display()
    vdisplay.start()

env = gym.make('CartPole-v1')

env.reset()

fig = plt.figure(figsize=(8,6))
plt.axis('off')
img = plt.imshow(env.render(mode='rgb_array'))
for _ in range(100):
    observation, reward, done, info = env.step(env.action_space.sample()) # needs action from DNN

    display.clear_output(wait=True)
    img.set_data(env.render(mode='rgb_array'))

    display.display(plt.gcf())

    if done:
        env.reset()
        
plt.close()

FuncAnimation

In [ ]:
import gym
from IPython import display
import matplotlib.pyplot as plt
from matplotlib import animation

# On Colab, start a virtual display before the environment is asked to render.
if is_colab:
    from pyvirtualdisplay import Display
    vdisplay = Display()
    vdisplay.start()

env = gym.make('CartPole-v1')
fig = plt.figure(figsize=(8,6))


def animate(i):
    """Advance the environment one random step and refresh the image artist.

    Called by FuncAnimation once per frame; ``i`` is the frame index and is
    not used. Reads the module-level ``env`` and ``patch``.

    Returns a 1-tuple containing the updated artist, i.e. an iterable of
    artists, so the callback also works if blitting is enabled.
    """
    observation, reward, done, info = env.step(env.action_space.sample()) # needs action from DNN
    # wait=True postpones the clear until new output is available.
    display.clear_output(wait=True)
    patch.set_data(env.render(mode='rgb_array'))
    # Start a new episode when the current one ends so stepping can continue.
    if done:
        env.reset()
    return (patch,)


# NOTE(review): render(mode='rgb_array') and the 4-value step() result follow
# the older Gym API -- confirm against the installed gym version.
try:
    env.reset()
    patch = plt.imshow(env.render(mode='rgb_array'))
    plt.axis('off')

    anim = animation.FuncAnimation(fig, animate, frames=100, interval=50, repeat=False)
    # to_jshtml() renders the frames into an HTML/JavaScript player that is
    # then shown in the cell output.
    display.display(display.HTML(anim.to_jshtml()))
finally:
    # Run the cleanup even if rendering fails: close the figure and release
    # the environment's rendering resources.
    plt.close(fig)
    env.close()
In [ ]: