For the theory, see the second reference link at the bottom; the video is detailed and not long. The code comes from Baidu senior engineer Li Kejiao (李科浇), a 2015-intake alumna of Peking University Shenzhen Graduate School (南燕) who now works at Baidu, probably around 27 or 28 by now. Honestly, this free open course exceeded my expectations.

Project structure:

Sarsa_FileFolder
->agent.py
->gridworld.py
->train.py

From the engineer's sharing session for graduating students: the second speaker was Li Kejiao, a 2015-intake alumna of the School of Information Engineering. She joined Baidu's AI research department as a development/test engineer and has since successfully transferred to an R&D engineer role in the same department. In her talk she first walked through her job-hunting process, from preparation and mass applications to interviews and making a final choice. She then shared her take on career planning from "The Long View" (《远见:如何规划职业生涯3大阶段》), stressing the importance of learning in the early stage of a career, ways of continually accumulating career "fuel", an approach to time management, and above all a mindset of always "embracing change". Finally she summed up the three things that have impressed her most in working life: do the work well and report proactively; learn to cooperate and build relationships; keep growing and embrace change.
Alright, the flattery is out of the way; now let's get down to business.
Sarsa in brief
Sarsa stands for state-action-reward-state'-action'. Training iterates on one question: in a given state, which action leads to the most reward? The result is a Q-table that gets built up and refined over time, with states as rows and actions as columns, updated using the rewards obtained from interacting with the environment. The update rule is:

Q(S_t, A_t) ← Q(S_t, A_t) + α [ R_{t+1} + γ·Q(S_{t+1}, A_{t+1}) − Q(S_t, A_t) ]

where α is the learning rate and γ is the reward discount factor.
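With some made-up numbers purely for illustration: suppose α = 0.1, γ = 0.9, the current estimate Q(s, a) = 0.5, the step reward is r = −1, and Q(s', a') = 2.0 for the next state-action pair. Then the new value is Q(s, a) ← 0.5 + 0.1 × (−1 + 0.9 × 2.0 − 0.5) = 0.5 + 0.1 × 0.3 = 0.53.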
To explore the environment better during training, Sarsa uses an ε-greedy policy: with a certain probability it outputs a random action instead of the greedy one.
Here is the Sarsa pseudocode, to make the core idea easier to grasp.
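A plain-text rendition of the usual textbook Sarsa control loop (this matches what the code below does):

```
Initialize Q(s, a) arbitrarily (e.g. all zeros)
Repeat for each episode:
    Initialize S
    Choose A from S using a policy derived from Q (e.g. ε-greedy)
    Repeat for each step of the episode:
        Take action A, observe R and S'
        Choose A' from S' using a policy derived from Q (e.g. ε-greedy)
        Q(S, A) ← Q(S, A) + α [ R + γ Q(S', A') − Q(S, A) ]
        S ← S';  A ← A'
    until S is terminal
```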
Once you understand the core idea, you can look at the implementation details (and if it has not clicked yet, go watch the teacher's video at the links below). She explains everything in great detail; this post is mainly my own notes. Remember to give the teacher's PARL project a Star.
Copy the three Python files into one folder and run train.py directly to see the output.
Install the dependencies:

```
pip install gym
pip install paddle
```
1. agent.py
```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-

import numpy as np


class SarsaAgent(object):
    def __init__(self,
                 obs_n,
                 act_n,
                 learning_rate=0.01,
                 gamma=0.9,
                 e_greed=0.1):
        self.act_n = act_n          # number of available actions
        self.lr = learning_rate     # learning rate
        self.gamma = gamma          # reward discount factor
        self.epsilon = e_greed      # probability of choosing a random action
        self.Q = np.zeros((obs_n, act_n))

    # sample an action for the given observation, with exploration
    def sample(self, obs):
        if np.random.uniform(0, 1) < (1.0 - self.epsilon):
            # pick the action with the highest Q value in the table
            action = self.predict(obs)
        else:
            # with probability epsilon, explore by picking a random action
            action = np.random.choice(self.act_n)
        return action

    # predict the greedy action for the given observation
    def predict(self, obs):
        Q_list = self.Q[obs, :]
        maxQ = np.max(Q_list)
        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
        action = np.random.choice(action_list)
        return action

    # learning method, i.e. how the Q-table gets updated
    def learn(self, obs, action, reward, next_obs, next_action, done):
        """ on-policy
            obs: observation before the interaction, s_t
            action: action chosen for this interaction, a_t
            reward: reward r obtained for this action
            next_obs: observation after the interaction, s_t+1
            next_action: action that will be taken in next_obs according to
                         the current Q-table, a_t+1
            done: whether the episode has ended
        """
        predict_Q = self.Q[obs, action]
        if done:
            target_Q = reward  # there is no next state
        else:
            target_Q = reward + self.gamma * self.Q[next_obs, next_action]  # Sarsa
        self.Q[obs, action] += self.lr * (target_Q - predict_Q)  # correct Q

    # save the Q-table to disk
    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q)
        print(npy_file + ' saved.')

    # restore the Q-table from disk
    def restore(self, npy_file='./q_table.npy'):
        self.Q = np.load(npy_file)
        print(npy_file + ' loaded.')
```
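As a quick sanity check of learn() (not part of the course code; the 4×2 table size and the numbers are made up for illustration), you can run something like:

```python
import numpy as np
from agent import SarsaAgent  # assumes agent.py is in the same folder

agent = SarsaAgent(obs_n=4, act_n=2, learning_rate=0.1, gamma=0.9, e_greed=0.1)
agent.Q[1, 0] = 2.0                      # pretend Q(s'=1, a'=0) is already 2.0
agent.learn(obs=0, action=1, reward=-1.0,
            next_obs=1, next_action=0, done=False)
# target = -1.0 + 0.9 * 2.0 = 0.8, so Q[0, 1] moves from 0.0 to 0.1 * 0.8 = 0.08
print(agent.Q[0, 1])                     # ~0.08, up to float precision
```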
2. gridworld.py (a helper module written by the teacher for rendering the CliffWalking-v0 environment; it is built on top of gym and just makes the visualization look nicer)
```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-

import gym
import turtle
import numpy as np

# turtle tutorial : https://docs.python.org/3.3/library/turtle.html


def GridWorld(gridmap=None, is_slippery=False):
    if gridmap is None:
        gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']
    env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=is_slippery)
    env = FrozenLakeWapper(env)
    return env


class FrozenLakeWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.max_y = env.desc.shape[0]
        self.max_x = env.desc.shape[1]
        self.t = None
        self.unit = 50

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for _ in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t == None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for i in range(self.desc.shape[0]):
                for j in range(self.desc.shape[1]):
                    x = j
                    y = self.max_y - 1 - i
                    if self.desc[i][j] == b'S':    # Start
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'F':  # Frozen ice
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'G':  # Goal
                        self.draw_box(x, y, 'yellow')
                    elif self.desc[i][j] == b'H':  # Hole
                        self.draw_box(x, y, 'black')
                    else:
                        self.draw_box(x, y, 'white')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


class CliffWalkingWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.t = None
        self.unit = 50
        self.max_x = 12
        self.max_y = 4

    def draw_x_line(self, y, x0, x1, color='gray'):
        assert x1 > x0
        self.t.color(color)
        self.t.setheading(0)
        self.t.up()
        self.t.goto(x0, y)
        self.t.down()
        self.t.forward(x1 - x0)

    def draw_y_line(self, x, y0, y1, color='gray'):
        assert y1 > y0
        self.t.color(color)
        self.t.setheading(90)
        self.t.up()
        self.t.goto(x, y0)
        self.t.down()
        self.t.forward(y1 - y0)

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for i in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t == None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            # draw the outer border
            for _ in range(2):
                self.t.forward(self.max_x * self.unit)
                self.t.left(90)
                self.t.forward(self.max_y * self.unit)
                self.t.left(90)
            # draw the grid lines
            for i in range(1, self.max_y):
                self.draw_x_line(
                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
            for i in range(1, self.max_x):
                self.draw_y_line(
                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)
            # the cliff cells are black, the goal cell is yellow
            for i in range(1, self.max_x - 1):
                self.draw_box(i, 0, 'black')
            self.draw_box(self.max_x - 1, 0, 'yellow')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


if __name__ == '__main__':
    # Environment 1: FrozenLake, the ice can be configured to be slippery or not
    # 0 left, 1 down, 2 right, 3 up
    env = gym.make("FrozenLake-v0", is_slippery=False)
    env = FrozenLakeWapper(env)

    # Environment 2: CliffWalking, the cliff environment
    # env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
    # env = CliffWalkingWapper(env)

    # Environment 3: custom grid world with a configurable map
    # S = Start, F = Floor, H = Hole, G = Goal
    # gridmap = [
    #     'SFFF',
    #     'FHFF',
    #     'FFFF',
    #     'HFGF']
    # env = GridWorld(gridmap)

    env.reset()
    for step in range(10):
        action = np.random.randint(0, 4)
        obs, reward, done, info = env.step(action)
        print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(
            step, action, obs, reward, done, info))
        # env.render()  # render one frame
```
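One detail worth noting in both wrappers: render() converts the flat state index self.s into grid coordinates before moving the turtle, flipping the y axis because gym numbers the states top-down while turtle draws from the bottom-left origin. A standalone sketch of that mapping for CliffWalking-v0 (the helper name state_to_xy is mine, not from the course code):

```python
# CliffWalking-v0 numbers its states 0..47 row by row on a 12 x 4 grid,
# with state 0 in the top-left corner; the wrapper flips y so it grows upward.
max_x, max_y = 12, 4

def state_to_xy(s):
    x = s % max_x                  # column index
    y = max_y - 1 - s // max_x     # flip rows so y grows upward
    return x, y

print(state_to_xy(36))  # start state, bottom-left corner  -> (0, 0)
print(state_to_xy(47))  # goal state, bottom-right corner -> (11, 0)
```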
3. train.py
```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-

import gym
from gridworld import CliffWalkingWapper, FrozenLakeWapper
from agent import SarsaAgent
import time


def run_episode(env, agent, render=False):
    total_steps = 0  # number of steps taken in this episode
    total_reward = 0

    obs = env.reset()           # reset the environment, i.e. start a new episode
    action = agent.sample(obs)  # choose an action according to the algorithm

    while True:
        next_obs, reward, done, _ = env.step(action)  # one interaction with the environment
        next_action = agent.sample(next_obs)          # choose the next action according to the algorithm
        # train the Sarsa algorithm
        agent.learn(obs, action, reward, next_obs, next_action, done)

        action = next_action
        obs = next_obs      # keep the previous observation
        total_reward += reward
        total_steps += 1    # count the steps
        if render:
            env.render()    # render a new frame
        if done:
            break
    return total_reward, total_steps


def test_episode(env, agent):
    total_reward = 0
    obs = env.reset()
    while True:
        action = agent.predict(obs)  # greedy
        next_obs, reward, done, _ = env.step(action)
        total_reward += reward
        obs = next_obs
        time.sleep(0.5)
        env.render()
        if done:
            print('test reward = %.1f' % (total_reward))
            break


def main():
    # env = gym.make("FrozenLake-v0", is_slippery=False)  # 0 left, 1 down, 2 right, 3 up
    # env = FrozenLakeWapper(env)

    env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
    env = CliffWalkingWapper(env)

    agent = SarsaAgent(
        obs_n=env.observation_space.n,
        act_n=env.action_space.n,
        learning_rate=0.1,
        gamma=0.9,
        e_greed=0.1)

    is_render = False
    for episode in range(500):
        ep_reward, ep_steps = run_episode(env, agent, is_render)
        print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps,
                                                          ep_reward))

        # render every 20 episodes to check the progress
        if episode % 20 == 0:
            is_render = True
        else:
            is_render = False
    # training finished, see how the algorithm performs
    test_episode(env, agent)


if __name__ == "__main__":
    main()
```
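Note that train.py never persists the learned table, although agent.py already provides save() and restore(). If you want to keep the result of a training run, one possible addition at the end of main() (my suggestion, not part of the original script) is:

```python
# possible addition at the end of main(), after test_episode(env, agent):
agent.save()                       # writes ./q_table.npy

# in a later run you could then skip training and just do:
# agent.restore('./q_table.npy')
# test_episode(env, agent)
```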
References (if this post was useful to you, please give it a like)
The video lecture for this lesson covers everything from theory to code in detail; the corresponding lesson is Lesson2_Sarsa.
强化学习 Sarsa 实战GYM下的CliffWalking爬悬崖游戏_Xurui_Luo的博客-CSDN博客_cliff walking
强化学习实战-使用Sarsa算法解决悬崖问题_心流-CSDN博客
【/强化学习7日打卡营-世界冠军带你从零实践/课程摘要和调参心得-No.1】强化学习初印象_FlyingPie的专栏-CSDN博客