陪你度过漫长岁月

技术总结《OpenAI Gym》

本文首先介绍Gym的核心函数调用链,然后介绍如何创建自定义的Gym环境,最后给出一些使用Gym过程中碰到的问题及其解决方案

Gym核心函数调用链

一般来说,使用Gym的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# main.py
import gym
def choose_action(o):
...
env = gym.make('CartPole-v0')
o = env.reset()
while True:
a = choose_action(o)
o_, r, done, info = env.step(a)
o = o_
if done:
break

可见,关键的函数有:

  • env = gym.make('CartPole-v0')
  • env.reset()
  • env.step(a)

我们先关注env.reset()env.step(a)。这两个函数是超类Env的成员函数,Env的相关代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# gym/core.py
class Env(object):
...
# Override in ALL subclasses
def _step(self, action): raise NotImplementedError
def _reset(self): raise NotImplementedError
...
def step(self, action):
return self._step(action)
def reset(self):
return self._reset()
...

可以看到这两个函数依赖于子类的_reset(self)_step(self, action)实现,子类CartPoleEnv的相关代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
# gym/envs/classic_control/CartPole.py
class CartPoleEnv(gym.Env):
...
def _step(self, action):
...
def _reset(self):
...
...

综上,env.reset()env.step(a)实际上是调用子类的_reset(self)_step(self, action)


下面我们关注gym.make('CartPole-v0'),它的实现如下:

1
2
3
4
5
6
7
8
9
# gym/envs/registration.py
# Have a global registry
registry = EnvRegistry()
...
def make(id):
return registry.make(id)

可以看到gym.make依赖于类EnvRegistry的成员函数makeEnvRegistry的相关代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# gym/envs/registration.py
class EnvRegistry(object):
def __init__(self):
# 注册表
# key: 环境名称(e.g., 'CartPole-v0')
# value:类型为EnvSpec,可以暂时理解为环境
self.env_specs = {}
def make(self, id):
...
# 根据环境名称,通过成员函数找到对应的环境
spec = self.spec(id)
# 实例化环境
env = spec.make()
...
return env
...
def spec(self, id):
...
...

可见类EnvRegistry的成员函数make依赖于类EnvSpec的成员函数makeEnvSpec的相关代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# gym/envs/registration.py
def load(name):
...
# EnvSpec与Env之间的关系类似于说明商品规格的订单与商品之间的关系,
# 下面用一个例子来说明:
# 假设你网购看中了一款衣服,那么你会挑选该款衣服的颜色、码数,然后再下单。
# 在这个例子里面,那款衣服就是Env,而说明该款衣服颜色、码数的订单就是EnvSpec。
# 这就是为什么EnvRegistry.make(self, id)中,在得到spec之后还要再spec.make(),
# 因为EnvSpec并不是Env,正如订单不是衣服。
class EnvSpec(object):
def __init__(self, id, entry_point=None, ...):
self.id = id
...
self._entry_point = entry_point
...
def make(self):
...
# 动态加载环境类
# 相当于以下代码
# from self._entry_point import classA
# cls = classA
cls = load(self._entry_point)
# 实例化环境
env = cls(**self._kwargs)
...
return env
...

至此,我们对Gym的核心函数调用链有了一个基本的了解:

  • gym.make(id):通过EnvRegistry中的注册表找到对应的EnvSpecEnvSpec根据entry_point动态import对应的Env,并将其实例化;
  • env.reset()env.step(a):子类的_reset(self)_step(self, action)

创建自定义环境

对Gym的核心函数调用链有了基本了解后,我们知道创建自定义环境的关键有两个:

  • 第一个是搭建自己的Env子类FooEnv
  • 第二个是注册FooEnv(i.e., 将FooEnv添加到registry.env_specs中),使得gym.make(id)可以找到FooEnv

官方文档推荐的自定义环境目录结构如下:

1
2
3
4
5
6
7
8
gym-foo/
README.md
setup.py #将gym_foo这个package加到系统环境变量中
gym_foo/ #核心部分
__init__.py #注册FooEnv
envs/
__init__.py
foo_env.py #实现FooEnv

实现FooEnv没什么特别的,就是根据自己的需求,实现_step(self, action)_reset(self)等函数。

值得一提的是注册FooEnv,我们无需自己实现注册环境的代码,因为Gym已经有现成的注册环境API,我们只需要调用该API即可。在我们的自定义环境中,负责注册FooEnv的文件为gym-foo/gym_foo/__init__.py,它的内容如下:

1
2
3
4
5
6
7
8
# gym-foo/gym_foo/__init__.py
from gym.envs.registration import register
register(
id='foo-v0', # 环境名
entry_point='gym_foo.envs:FooEnv', # 环境类,之后就根据这个路径动态import环境
)

可见,注册的关键是register函数,而register函数的实现如下:

1
2
3
4
5
6
7
8
9
10
11
# gym/envs/registration.py
# Have a global registry
registry = EnvRegistry()
# Gym的注册环境API
def register(id, **kwargs):
return registry.register(id, **kwargs)
def make(id):
return registry.make(id)

可以看到register的实现依赖于类EnvRegistry的成员函数register,其相关代码如下:

1
2
3
4
5
6
7
8
9
# gym/envs/registration.py
class EnvRegistry(object):
...
def register(self, id, **kwargs):
...
# 将FooEnv对应的“订单”写到“注册表”上
self.env_specs[id] = EnvSpec(id, **kwargs)

综上,我们可以通过API函数register注册自定义的环境FooEnv

注意事项

server render

假如你通过ssh连接server,在server上运行(i.e., python main.py)以下代码(关键点在使用env.render()保存录像):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# main.py
import gym
from gym import wrappers
env = gym.make('CartPole-v0')
env = wrappers.Monitor(env, 'video')
for i_episode in range(20):
observation = env.reset()
for t in range(100):
env.render()
action = env.action_space.sample()
observation, reward, done, info = env.step(action)
if done:
break

那么你会得到一个报错,报错的信息大概是pyglet.canvas.xlib.NoSuchDisplayException: Cannot connect to "None"

原因大概是env.render()需要图形界面(就是弹出来的那个框框),而当你使用ssh连接server时是没有图形界面的。因此我们需要一个虚拟的图形界面,而xvfb-run就是一个提供虚拟图形界面的工具。

所以我们需要使用xvfb-run -a -s "-screen 0 1400x900x24 +extension RANDR" -- python main.py来运行我们的代码。

一般来说,运行上述指令是会报错的,报错的信息大概是pyglet requires an X server with GLX,主要原因在于显卡驱动以及cuda的安装有问题,没有加--no-opengl的flag。解决方案可以参考这里这里

保存每一段episode的录像

wrappers.Monitor默认不会保存所有episode的录像,但我们可以通过以下代码来设置保存所有episode的录像:

1
env = wrappers.Monitor(env, 'video', video_callable=lambda episode_id: True)

动态修改episode的最大step

env._max_episode_steps = xxx。注意,这仅当env的类型为TimeLimit时可用。

关于wrapper

  • 相同的两个wrapper不能叠加(e.g., Monitor不可以和Monitor叠加,但是Monitor可以和TimeLimit叠加),否则会报double wrapper的错。
  • 在注册FooEnv时,加不加max_episode_steps=xxx会影响返回的Env的类型。假如加了,返回的是TimeLimit类型的wrapper;假如不加,返回的就是裸的FooEnv
  • Monitor里面有两个recorder,一个是stat_recorder,用于保存数据(reward之类的);另一个是video_recorder,用于录像。Monitor会在每一次调用env.resetenv.step之后调用render

屏蔽log信息

1
2
3
4
5
6
7
8
9
10
# main.py
import logging
# suppress INFO level logging 'Making new env: ...'
logging.getLogger('gym.envs.registration').setLevel(logging.WARNING)
# suppress INFO level logging 'Starting new video recorder writing to ...'
logging.getLogger('gym.monitoring.video_recorder').setLevel(logging.WARNING)
# suppress INFO level logging 'Creating monitor directory ...'
logging.getLogger('gym.wrappers.monitoring').setLevel(logging.WARNING)