
simple_reinforcement_learning's Issues

Improving the greedy algorithm

In the improved greedy algorithm, when the action function decays the exploration probability, shouldn't this played_count have the length of rewards subtracted from it, so that the samples added during initialization are excluded from the count?

Hoping someone can explain.
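For concreteness, here is a minimal sketch of the decayed-epsilon selection being asked about, assuming the tutorial's setup where rewards keeps one list per arm and each list is seeded with one sample at initialization (the names here are illustrative, not the course's exact code):

import random

#one reward list per arm, each seeded with one initial sample
rewards = [[1.0] for _ in range(10)]

def choose_one():
    #total number of plays so far, including the seeded samples
    played_count = sum(len(r) for r in rewards)

    #the question: should this be played_count - len(rewards),
    #so that the seeded samples do not count as real plays?
    if random.random() < 1 / played_count:
        return random.randint(0, len(rewards) - 1)

    #otherwise exploit the arm with the best mean reward so far
    means = [sum(r) / len(r) for r in rewards]
    return means.index(max(means))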

I wrote an AC implementation for the inverted pendulum, following the AC lesson and combining it with what was covered earlier. Why does it perform so terribly?! Please point out what's wrong QAQ

import gym

#######################Define the environment########################
class MyWrapper(gym.Wrapper):
    def __init__(self):
        env = gym.make('Pendulum-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        self.step_n += 1
        if self.step_n >= 200:
            done = True
        return state, reward, done, info

env = MyWrapper()
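As a quick sanity check of the wrapper (assuming Gym's Pendulum-v1 observation of cos(theta), sin(theta) and angular velocity):

state = env.reset()
print(state.shape)  #-> (3,), matching the 3-dimensional input of the networks below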

####################Show the game######################################

from matplotlib import pyplot as plt

%matplotlib inline

#render the game
def show():
    plt.imshow(env.render())
    plt.show()

##################Create the networks#################################

Create two networks: Actor and Critic.

import torch

model_actor = torch.nn.Sequential(
    torch.nn.Linear(3, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 11),
    torch.nn.Softmax(dim=1),
)

model_critic = torch.nn.Sequential(
    torch.nn.Linear(3, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 1),
)

####################Get an action##################################

import random

def get_action(state):
    state = torch.FloatTensor(state).reshape(1, 3)

    prob = model_actor(state)

    #sample one of the 11 discrete actions according to the policy
    action = random.choices(range(11), weights=prob[0].tolist(), k=1)[0]

    #map the discrete index onto the continuous torque range [-2, 2]
    action_continuous = action / 10 * 4 - 2

    return action, action_continuous
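The endpoints of that mapping can be checked directly (a throwaway snippet, not part of the training code):

for k in (0, 5, 10):
    print(k, k / 10 * 4 - 2)  #-> -2.0, 0.0, 2.0: the full torque range of Pendulum-v1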

########################Collect data##################################

def get_data():
    states = []
    rewards = []
    actions = []
    next_states = []
    overs = []

    #initialize the game
    state = env.reset()
    over = False
    while not over:
        action, action_continuous = get_action(state)
        #print(f'taking action: {action_continuous}')
        next_state, reward, over, _ = env.step([action_continuous])
        states.append(state)
        rewards.append(reward)
        actions.append(action)
        next_states.append(next_state)
        overs.append(over)

        state = next_state

    states = torch.FloatTensor(states).reshape(-1, 3)
    rewards = torch.FloatTensor(rewards).reshape(-1, 1)
    actions = torch.LongTensor(actions).reshape(-1, 1)
    next_states = torch.FloatTensor(next_states).reshape(-1, 3)
    overs = torch.LongTensor(overs).reshape(-1, 1)

    return states, rewards, actions, next_states, overs
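One property of this environment worth keeping in mind when comparing against the course's other Pendulum examples: the raw reward lies roughly in [-16.27, 0], and tutorials of this style often rescale it before training, e.g. (an optional normalization, not part of the posted code):

#rescale Pendulum's reward from roughly [-16.27, 0] to roughly [-1, 1]
rewards = (rewards + 8) / 8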

#########################Test the game#########################################
from IPython import display

def test(play):
    #initialize the game
    state = env.reset()

    #track the total reward; the higher the better
    reward_sum = 0

    #play until the game is over
    over = False
    while not over:

        #get an action for the current state
        action, action_continuous = get_action(state)

        #take the action and observe the feedback
        state, reward, over, _ = env.step([action_continuous])
        reward_sum += reward

        #render the animation
        if play:
            display.clear_output(wait=True)
            show()

    return reward_sum

###########################Training function##################################
def train():
    optimizer_actor = torch.optim.Adam(model_actor.parameters(), lr=2e-3)
    optimizer_critic = torch.optim.Adam(model_critic.parameters(), lr=1e-2)

    loss_fn = torch.nn.MSELoss()

    for i in range(2000):
        states, rewards, actions, next_states, overs = get_data()

        #critic's value estimates for the visited states
        values = model_critic(states)

        #TD targets: reward + discounted value of the next state
        targets = model_critic(next_states)
        targets *= 0.98
        targets *= 1 - overs
        targets += rewards

        delta = (values - targets).detach()

        #probability the actor assigned to the action actually taken
        probs = model_actor(states)
        probs = probs.gather(dim=1, index=actions)

        loss = (-probs.log() * delta).mean()

        loss_critic = loss_fn(values, targets.detach())

        optimizer_actor.zero_grad()
        loss.backward()
        optimizer_actor.step()

        optimizer_critic.zero_grad()
        loss_critic.backward()
        optimizer_critic.step()

        if i % 100 == 0:
            test_result = sum([test(play=False) for _ in range(10)]) / 10
            print(f"epoch:{i},score:{test_result}")

train()
[Screenshot: 微信截图_20230525144652]
[Attachment: 13AC倒立摆.pdf]
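For comparison, the conventional one-step actor-critic computes the TD error as target minus value, i.e. delta = r + gamma * V(s') - V(s). Below is a minimal reference sketch of the two losses under that convention, reusing the names from the code above (a sketch of the textbook update, not the course's exact code):

def actor_critic_losses(states, rewards, actions, next_states, overs):
    values = model_critic(states)
    targets = rewards + 0.98 * model_critic(next_states) * (1 - overs)

    #TD error / advantage: target minus value (note the order)
    delta = (targets - values).detach()

    #probability of the action actually taken
    probs = model_actor(states).gather(dim=1, index=actions)

    loss_actor = (-probs.log() * delta).mean()
    loss_critic = torch.nn.MSELoss()(values, targets.detach())
    return loss_actor, loss_critic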
