import random

import gym
import torch
import numpy as np
from matplotlib import pyplot as plt
from IPython import display

# Hyper-parameters used throughout the script.
EPSILON = 0.1            # probability of taking a random (exploratory) action
GAMMA = 0.98             # discount factor applied to the next-state value
BATCH_SIZE = 64          # number of transitions drawn per gradient step
MIN_NEW_SAMPLES = 200    # minimum number of transitions collected per refresh
MAX_POOL_SIZE = 10000    # maximum number of transitions kept in the pool

# NOTE: this script uses the classic Gym API, where reset() returns only the
# observation and step() returns a 4-tuple (obs, reward, done, info).
env = gym.make("CartPole-v0")
# Initial observation of the agent.
state = env.reset()
# Size of the discrete action space.
actions = env.action_space.n
print(state, actions)

# Render the game.
# plt.imshow(env.render(mode='rgb_array'))
# plt.show()

# Action model (policy network): 4 observation features in, 2 action scores out.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
)

# Target network: same architecture, used to evaluate next-state values.
next_model = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
)

# Start the target network from the policy network's parameters.
next_model.load_state_dict(model.state_dict())


def get_action(state):
    """Choose an action for ``state`` with an epsilon-greedy policy.

    Args:
        state: the current observation of the agent (4 values).

    Returns:
        int: a random action with probability ``EPSILON``, otherwise the
        action with the highest score according to ``model``.
    """
    if random.random() < EPSILON:
        return random.choice(range(2))

    # Run the policy network and pick the highest-scoring action.
    state = torch.FloatTensor(state).reshape(1, 4)
    # No gradient is needed when only selecting an action.
    with torch.no_grad():
        return model(state).argmax().item()


# Data pool (replay buffer) of (state, action, reward, next_state, done) tuples.
datas = []


def update_data():
    """Append new transitions to the pool and drop the oldest overflow.

    Whole episodes are played until at least ``MIN_NEW_SAMPLES`` transitions
    have been added; afterwards the pool is trimmed from the front so that it
    holds at most ``MAX_POOL_SIZE`` entries.

    Returns:
        int: the number of transitions added by this call (before trimming).
    """
    count = len(datas)

    # Keep playing full episodes until enough new transitions are gathered.
    while len(datas) - count < MIN_NEW_SAMPLES:
        state = env.reset()
        done = False
        while not done:
            # Pick an action from the current state and step the environment.
            action = get_action(state)
            next_state, reward, done, _ = env.step(action)
            datas.append((state, action, reward, next_state, done))
            # Advance to the next state.
            state = next_state

    update_count = len(datas) - count

    # Drop the oldest entries in one slice deletion instead of repeated
    # pop(0) calls, each of which shifts the whole list.
    overflow = len(datas) - MAX_POOL_SIZE
    if overflow > 0:
        del datas[:overflow]

    return update_count


def get_sample():
    """Draw a random batch of ``BATCH_SIZE`` transitions from the pool.

    Returns:
        tuple: ``(state, action, reward, next_state, done)`` tensors, where
        ``state``/``next_state``/``reward`` are float tensors and
        ``action``/``done`` are long tensors.

    Raises:
        ValueError: if the pool holds fewer than ``BATCH_SIZE`` transitions
            (raised by ``random.sample``).
    """
    samples = random.sample(datas, BATCH_SIZE)
    states, acts, rewards, next_states, dones = zip(*samples)

    # Stack observations with NumPy first: building a tensor directly from a
    # list of arrays is very slow.
    state = torch.as_tensor(np.array(states, dtype=np.float32))
    action = torch.LongTensor(list(acts))
    reward = torch.FloatTensor(list(rewards))
    next_state = torch.as_tensor(np.array(next_states, dtype=np.float32))
    done = torch.LongTensor(list(dones))
    return state, action, reward, next_state, done


def get_value(state, action):
    """Return the policy network's score for each taken action.

    Args:
        state: float tensor of observations, one row per sample.
        action: long tensor with the action index taken for each row.

    Returns:
        Tensor: one score per row, selected from ``model(state)``.
    """
    value = model(state)
    # gather() picks value[i, action[i]] for every row, for any batch size.
    return value.gather(1, action.unsqueeze(1)).squeeze(1)


def get_target(next_state, reward, done):
    """Compute the learning target ``reward + GAMMA * max_a Q_target(s', a)``.

    Args:
        next_state: float tensor of next observations, one row per sample.
        reward: float tensor of rewards, one per sample.
        done: tensor flagging rows whose episode ended at ``next_state``.

    Returns:
        Tensor: one target value per row; no gradient flows into it.
    """
    # The estimate of the next-state value comes from the target network.
    with torch.no_grad():
        next_value = next_model(next_state)

    # Greedily take the best action value.
    target = next_value.max(dim=1)[0]
    # A finished episode has no future value.
    target = target.masked_fill(done.bool(), 0.0)
    return reward + target * GAMMA


def test():
    """Play one episode with ``get_action`` and return its total reward."""
    reward_sum = 0
    state = env.reset()
    done = False
    while not done:
        action = get_action(state)
        next_state, reward, done, _ = env.step(action)
        reward_sum += reward
        state = next_state
    return reward_sum


def train():
    """Train ``model`` and periodically sync it into ``next_model``.

    Each of the 600 epochs refreshes the data pool and then performs 200
    gradient steps on random batches. Every 10 steps the target network is
    overwritten with the policy network's parameters, and every 50 epochs the
    average score over 50 test episodes is printed.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    loss_fn = torch.nn.MSELoss()

    for epoch in range(600):
        # Refresh the pool with a batch of new transitions.
        update_counter = update_data()

        # After refreshing the data, learn from it N times.
        for i in range(200):
            state, action, reward, next_state, done = get_sample()

            # Compute the predicted value and its target.
            value = get_value(state, action)
            target = get_target(next_state, reward, done)

            # Parameter update.
            loss = loss_fn(value, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Periodically update the target network.
            if (i + 1) % 10 == 0:
                next_model.load_state_dict(model.state_dict())

        if epoch % 50 == 0:
            test_score = sum(test() for _ in range(50)) / 50
            print(epoch, len(datas), update_counter, test_score)
使用目标网络(双模型)后,平均得分更高,效果好于单模型。
原文地址:http://www.cnblogs.com/demo-deng/p/16886582.html
1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长!
2. 分享目的仅供大家学习和交流,请勿用于商业用途!
3. 如果你也有好源码或者教程,可以到用户中心发布,分享有积分奖励和额外收入!
4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解!
5. 如有链接无法下载、失效或广告,请联系管理员处理!
6. 本站资源售价只是赞助,收取费用仅维持本站的日常运营所需!
7. 如遇到加密压缩包,默认解压密码为"gltf",如遇到无法解压的请联系管理员!
8. 因为资源和程序源码均为可复制品,所以不支持任何理由的退款兑现,请斟酌后支付下载
声明:如果标题没有注明"已测试"或者"测试可用"等字样的资源源码均未经过站长测试。特别注意:没有标注的源码不保证任何可用性。