用50行Python代码解决cart pole平衡问题


作者:Mike Shi
编译:Bing
今天的这篇文章向大家展示,如何用50行Python代码教会机器解决cart pole问题,保持平衡。原文作者Mike Shi将用标准的OpenAI Gym作为测试环境,仅用NumPy创建我们的智能体。
Cart pole的玩法如下动图所示,目标就是保持一根杆一直竖直朝上,杆由于重力原因会一直倾斜,当杆倾斜到一定程度就会倒下,此时需要朝左或者右移动杆保证它不会倒下来。这和在指尖上树立一只铅笔一样,只不过cart pole是一维的。

在详细讲解前可以先试一下我们最终的demo:

地址: towardsdatascience.com/from-scratch-ai-balancing-act-in-50-lines-of-python-7ea67ef717

强化学习速成

如果你是机器学习新手,或是首次接触强化学习,我会在这里介绍一些可能用到的基础术语。如果你已经很熟悉强化学习的概念了,可以跳过这部分。
强化学习
强化学习的主要任务是让智能体在没有明确指令的情况下学会执行特定任务,或作出特定动作。可以把它想象成一个婴儿,它正随机地伸展腿部,在一次偶然的情况中,婴儿可以竖直站立了,这时我们会给他一颗糖作为奖励。同样,智能体的目标就是在有限的时间内实现总体奖励的最大化,并且我们要根据想完成的任务决定奖励类型。对于婴儿站立的案例,如果直立站立,奖励就是1,否则的话就是0。
强化学习智能体的一个例子就是AlphaGo,该智能体学习如何玩游戏才能使奖励最大化(赢得游戏)。在这篇教程中,我们将创造一个智能体,通过左右移动杆子,解决cart pole问题。
状态

游戏《乓》的界面

状态是某一时刻游戏的样子,我们通常会处理游戏的多种表示。在《乓》中状态可能是球拍的竖直位置和x, y坐标以及乒乓球的移动速度。在cart pole游戏中,我们的状态由四个数字组成:底部小车(cart)的位置,小车的速度,杆子的位置(用角度表示)和杆子角度变化的速度。这四组数据以群组的方式(或向量)给定出来。这非常重要,理解这组表示状态的数字意味着我们可以用数学推理决定接下来做出怎样的反应。
策略
策略是一种函数,输入游戏的状态(例如,位置参数),输出智能体应该做出的动作。智能体采取我们所选择动作后,游戏会根据下一个状态进行更新,这一过程会一直持续到游戏结束。策略是非常重要的,也是我们一直追求的,这是智能体背后的决策能力体现。
点积
两数组(向量)间的点积就是简单地将第一组中的每个元素和对应的元素相乘,然后把它们结合在一起假设我们想找到数组A和B之间的点积,只需要简单计算A[0]B[0] + A[1]B[1]…即可。我们会用这个公式将状态(数组)和另一个数组(我们的策略)相乘。

创建我们的策略

为了解决cart pole问题,我们想让机器学习一种策略赢得最大奖励。
对于我们要创建的智能体,我们会用四个数字来表示策略,这四个数字可以体现状态中的每个元素的重要性(例如小车的位置、杆子的位置等等),之后,我们会计算策略数组和状态之间的点击,输出一个单一数字。根据数值的正负,我们决定让小车向左还是向右。
如果这种描述听起来比较抽象,那我们接下来用一个具体例子来展示这一过程。
假设底部小车在中央是静止的,杆子向右倾斜,并且可能会倒向右边:

相关的状态可能会如下:

那么状态的数组可能是[0, 0, 0.2, 0.05]。
我们的直觉是,要想让杆子竖直,就要将小车推向右。我在训练过程中得到了一个策略结果,它的数组如下:[-0.116, 0.332, 0.207, 0.352]。让我们快速地计算一下,看看在这一状态下会输出怎样的动作。
这里,我们将上述策略数组和状态[0, 0, 0.2, 0.05]结合计算点积。如果得出的结果是正的,那么就将小车向右推,如果是负的就向左。

结果是正数,也就是说该策略会让将小车推到右边,就像我们做的一样。现在,问题时我们如何才能得到那四个数字?如果随机选择会怎样?模型会表现得怎么样?

开始你的编辑

首先在repl.it上打开一个Python实例,你能从中获取大量不同的编程环境实例,然后用强大的云IDE编写代码。

安装软件包

首先我们要安装两个必要的包:帮助进行数值计算的NumPy,和智能体模拟器OpenAI Gym。

搭建基础框架

首先,导入两个在 main.py 脚本中安装的依赖,然后设置一个新的gym环境:

import gym
import numpy as np
env = gym.make('CartPole-v1')

接下来,我们将定义一个名为“play”的函数,其中将有一个环境和策略数组,并且在环境中计算策略数组,并返回一个分数,在每次游戏迭代时都会进行记录。我们会根据分数判断策略的效果,并根据每次的游戏记录判断策略的表现。这就是在游戏中如何测试不同策略,并判断它们效果的方法。
首先,让我们清楚函数的定义,再把游戏设置为初始状态。

def play(env, policy):
  observation = env.reset()

接下来,我们要设立一些变量,进行追踪,观察游戏是否已经结束,包括策略的总分、游戏中每一步的快照。

done = False
  score = 0
  observations = []

现在,已经对游戏进行了多次运行,直到gym告诉我们游戏已经完成。

for _ in range(5000):
    observations += [observation.tolist()] # Record the observations for normalization and replay
 if done: # If the simulation was over last iteration, exit loop
 break
 # Pick an action according to the policy matrix
    outcome = np.dot(policy, observation)
    action = 1 if outcome > 0 else 0
 # Make the action, record reward
    observation, reward, done, info = env.step(action)
    score += reward
 return score, observations

上述代码主要是玩游戏的过程以及记录输出,实际上,我们的策略代码只需要两行:

outcome = np.dot(policy, observation)
    action = 1 if outcome > 0 else 0

这里我们主要是进行策略数组和状态数组之间的点积运算,之后,我们会根据结果进行动作选择。
目前为止,我们的 main.py 应该如下所示:

import gym
import numpy as np
env = gym.make('CartPole-v1')
def play(env, policy):
  observation = env.reset()
  done = False
  score = 0
  observations = []
 for _ in range(5000):
    observations += [observation.tolist()] # Record the observations for normalization and replay
 if done: # If the simulation was over last iteration, exit loop
 break
 # Pick an action according to the policy matrix
    outcome = np.dot(policy, observation)
    action = 1 if outcome > 0 else 0
 # Make the action, record reward
    observation, reward, done, info = env.step(action)
    score += reward
 return score, observations

现在我们开始玩游戏,找到最佳策略!

开始第一局游戏

现在我们有了能玩游戏,并且能判断策略好坏的函数,现在我们想生成一些其他策略,看看它们能做什么。
如果输入一些随机策略会是怎样?我们能得到什么结果?用 numpy 生成我们的策略,这是一个含有四个元素的数组,或者1×4矩阵。

policy = np.random.rand(1,4)

得到了策略和上述我们创造的环境,我们可以开始游戏,并得到一个分数:

score, observations = play(env, policy)
print('Policy Score', score)

运行脚本后,它应该输出我们策略得到的分数。

游戏最大分是500。

观察智能体

为了观察智能体,我们用Flask设置一个轻量级服务器,可以在浏览器中看到智能体的表现。Flask是一款简洁的Python HTTP服务器框架,可以服务我们的HTML UI和数据。
首先安装Flask的Python包:

接着,在我们脚本的最下方创建一个flask服务器。他能在端点 /data 上显示游戏中每一帧的记录,在 / 上显示UI。

from flask import Flask
import json
app = Flask(__name__, static_folder='.')
@app.route("/data")
def data():
 return json.dumps(observations)
@app.route('/')
def root():
 return app.send_static_file('./index.html')
app.run(host='0.0.0.0', port='3000')

同时,我们需要添加两个文件夹,其中一个是空白的Python文件夹。接着我们还想创建一个index.html,可以渲染UI。具体过程这里不详细展开说明了,但是你需要将这个index.html上传到你的repl.it项目中。
你现在应该有了这样的项目文件夹:

有了这两个文件夹,当我们运行repl时,它仍然可以演示我们的策略。一切就绪后,就能尝试超出最佳策略了。

策略搜索

最初我们随机选择了策略,但如果我们在很多策略中,只保留表现最佳的那个会怎样?
让我们回到运行策略的部分,编写一段循环,生成多个策略,并记录它们的表现,只保存表现最佳的策略。
我们首先要创造一个名为max的元组,可以存储分数、观察到的场景和目前最佳的策略数组。

max = (0, [], [])

接下来,我们会生成10种策略并进行评估,保存最佳策略。

for _ in range(10):
  policy = np.random.rand(1,4)
  score, observations = play(env, policy)
 if score > max[0]:
    max = (score, observations, policy)
print('Max Score', max[0])

我们同样要告诉/data端点,返回并重新演示最佳策略。该端点为:

@app.route("/data")
def data():
 return json.dumps(observations)

经过转变成为:

@app.route("/data")
def data():
 return json.dumps(max[1])

现在你的main.py看起来应该是这样:

import gym
import numpy as np
env = gym.make('CartPole-v1')
def play(env, policy):
  observation = env.reset()
  done = False
  score = 0
  observations = []
 for _ in range(5000):
    observations += [observation.tolist()] # Record the observations for normalization and replay
 if done: # If the simulation was over last iteration, exit loop
 break
 # Pick an action according to the policy matrix
    outcome = np.dot(policy, observation)
    action = 1 if outcome > 0 else 0
 # Make the action, record reward
    observation, reward, done, info = env.step(action)
    score += reward
 return score, observations
max = (0, [], [])
for _ in range(10):
  policy = np.random.rand(1,4)
  score, observations = play(env, policy)
 if score > max[0]:
    max = (score, observations, policy)
print('Max Score', max[0])
from flask import Flask
import json
app = Flask(__name__, static_folder='.')
@app.route("/data")
def data():
 return json.dumps(max[1])
@app.route('/')
def root():
 return app.send_static_file('./index.html')
app.run(host='0.0.0.0', port='3000')

如果现在运行repl,我们应该会得到的最大分数为500。如果没有,可以试试再运行repl一遍!

不足之处

但是,我们在第一部分进行了小小的作弊。首先,我们随机创造了策略数组,范围在0到1之间,这恰巧能运行。但如果我们稍稍改变数值,智能体就会完全失效,可以试一试将 action = 1 if outcome > 0 else 0 改为 action = 1 if outcome < 0 else 0 。这样就很不稳定,为了解决这一问题,我们应该提出一种对负数也能运行的策略。这样难度就又增加了,但是算法之后也能更通用。
我们将 policy = np.random.rand(1,4) 改为 policy = np.random.rand(1,4) - 0.5 ,把策略中的每个数值从0—1改为-0.5—0.5。我们还想让更多策略能够进行搜索,所以在上述循环中,与其在10种策略中进行迭代,我们将策略增加到100种,代码改为 for _ in range(100):
最终main.py变成了:

import gym
import numpy as np
env = gym.make('CartPole-v1')
def play(env, policy):
  observation = env.reset()
  done = False
  score = 0
  observations = []
 for _ in range(5000):
    observations += [observation.tolist()] # Record the observations for normalization and replay
 if done: # If the simulation was over last iteration, exit loop
 break
 # Pick an action according to the policy matrix
    outcome = np.dot(policy, observation)
    action = 1 if outcome > 0 else 0
 # Make the action, record reward
    observation, reward, done, info = env.step(action)
    score += reward
 return score, observations
max = (0, [], [])
# We changed the next two lines!
for _ in range(100):
  policy = np.random.rand(1,4) - 0.5
  score, observations = play(env, policy)
 if score > max[0]:
    max = (score, observations, policy)
print('Max Score', max[0])
from flask import Flask
import json
app = Flask(__name__, static_folder='.')
@app.route("/data")
def data():
 return json.dumps(max[1])
@app.route('/')
def root():
 return app.send_static_file('./index.html')
app.run(host='0.0.0.0', port='3000')

另外一点,既然我们的策略可以在单次运行中达到最高500分,它能保证每次都表现得这么好吗?当我们生成了100中策略,然后选择表现最佳的策略,但这一策略可能恰巧表现得很好,这是由于游戏本身具有一个随机元素(起始位置每次都不同),所以一种策略只在同一个起始点表现不错,而其他起始点可能就无效了。
为了解决这一点,我们想要评测一种策略能在多次试验中的表现。现在我们选取此前表现最好的策略,看看在100次实验中的表现如何:

scores = []
for _ in range(100):
  score, _  = play(env, max[2])
  scores += [score]
print('Average Score (100 trials)', np.mean(scores))

我们记录下了每次的分数,然后用numpy计算平均分,在终端上打印出来。

结语

恭喜:tada:,现在我们已经成功地创建一个可以解决cart pole问题的AI了,不仅有用而且非常高效。接下来,这一模型还有几处需要改进的空间:

  • 找到一个真正的最优策略
  • 减少找到最优策略的次数
  • 研究如何能找到正确的策略,而不是随即进行选择
  • 解决其他环境。

发表评论

电子邮件地址不会被公开。 必填项已用*标注