Paddle 2.2.0: Policy Gradient Algorithm Implementation

In earlier posts we trained agents with DQN and related algorithms and reached fairly high scores. DQN's network outputs a Q-value for each action, and the agent then takes the action with the largest Q-value. Why not have the network output the action (probabilities) directly, in a single step? That is exactly what Policy Gradient does.
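
The learn() step in the code below implements this as the REINFORCE (Monte Carlo policy gradient) loss: the network outputs the action distribution \pi_\theta(a \mid s), and the loss is the negative return-weighted log-likelihood of the sampled actions,

\nabla_\theta J(\theta) \approx \frac{1}{T}\sum_{t=0}^{T-1} G_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t), \qquad \text{loss}(\theta) = -\frac{1}{T}\sum_{t=0}^{T-1} G_t \log \pi_\theta(a_t \mid s_t)

where G_t is the (discounted) return from step t onward, computed by calc_reward_to_go further down.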

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import parl
import numpy as np
import gym
from parl.utils import logger
from paddle.distribution import Categorical

LEARNING_RATE = 1e-3

class Model(parl.Model):
    """Policy network: two fully connected layers that output action probabilities."""

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        hid1_size = act_dim * 10
        self.fc1 = nn.Linear(obs_dim, hid1_size)
        self.fc2 = nn.Linear(hid1_size, act_dim)

    def forward(self, obs):
        out = F.tanh(self.fc1(obs))
        prob = F.softmax(self.fc2(out), axis=-1)  # probability of each action
        return prob

class PolicyGradient(parl.Algorithm):
    def __init__(self, model, lr=None):
        self.model = model
        assert isinstance(lr, float)
        self.optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=model.parameters())

    def predict(self, obs):
        return self.model(obs)

    def learn(self, obs, act, reward):
        """One REINFORCE update on a batch of (obs, action, return) from an episode."""
        prob = self.model(obs)                      # action probabilities, shape [N, act_dim]
        log_prob = Categorical(prob).log_prob(act)  # log pi(a_t | s_t) of the taken actions
        loss = paddle.mean(-1 * log_prob * reward)  # -mean(G_t * log pi(a_t | s_t))

        self.optimizer.clear_grad()
        loss.backward()
        self.optimizer.step()
        return loss

class Agent(parl.Agent):
    def __init__(self, algorithm, obs_dim, act_dim):
        super().__init__(algorithm)
        self.obs_dim = obs_dim
        self.act_dim = act_dim

    def sample(self, obs):
        """Sample an action from the policy's probability distribution (used during training)."""
        obs = paddle.to_tensor(obs, dtype='float32')
        act_prob = self.alg.predict(obs)
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.random.choice(range(self.act_dim), p=act_prob.numpy())
        return act
    
    def predict(self, obs):
        """Pick the most probable action (greedy, used for evaluation)."""
        obs = paddle.to_tensor(obs, dtype='float32')
        act_prob = self.alg.predict(obs)
        act = np.argmax(act_prob)
        return act

    def learn(self, obs, act, reward):
        # Add a trailing axis: [N] -> [N, 1] before converting to tensors.
        act = np.expand_dims(act, axis=-1)
        reward = np.expand_dims(reward, axis=-1)

        obs = paddle.to_tensor(obs, dtype='float32')
        act = paddle.to_tensor(act, dtype='int32')
        reward = paddle.to_tensor(reward, dtype='float32')
        loss = self.alg.learn(obs, act, reward)
        return loss.numpy()[0]

def run_episode(env, agent):
    """Run one episode with stochastic actions; return per-step obs, action and reward lists."""
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        action = agent.sample(obs)  # sample an action from the policy
        action_list.append(action)

        obs, reward, done, info = env.step(action)
        reward_list.append(reward)

        if done:
            break
    return obs_list, action_list, reward_list


# Evaluate the agent: run 5 episodes and average the total reward
def evaluate(env, agent, render=False):
    eval_reward = []
    for i in range(5):
        obs = env.reset()
        episode_reward = 0
        while True:
            action = agent.predict(obs)  # greedy action
            obs, reward, isOver, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if isOver:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)

# Given one episode's per-step reward list, compute the return G_t for every step
def calc_reward_to_go(reward_list, gamma=1.0):
    for i in range(len(reward_list) - 2, -1, -1):
        # G_t = r_t + γ·r_t+1 + ... = r_t + γ·G_t+1
        reward_list[i] += gamma * reward_list[i + 1]  # Gt
    return np.array(reward_list)
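
As a quick toy check of this recursion (illustrative only, not part of the training script): with gamma=0.9 and rewards [1, 1, 1], the returns are G_2 = 1, G_1 = 1 + 0.9*1 = 1.9 and G_0 = 1 + 0.9*1.9 = 2.71.

# Illustrative check of calc_reward_to_go (not used by the training loop)
print(calc_reward_to_go([1.0, 1.0, 1.0], gamma=0.9))  # -> approximately [2.71, 1.9, 1.0]

The training loop below calls calc_reward_to_go with the default gamma=1.0, so G_t is simply the sum of the remaining rewards of the episode.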


# Create the environment
env = gym.make('CartPole-v0')
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n
logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

# Build the agent following the PARL structure: model -> algorithm -> agent
model = Model(obs_dim, act_dim)
alg = PolicyGradient(model, lr=LEARNING_RATE)
agent = Agent(alg, obs_dim=obs_dim, act_dim=act_dim)

# Load a previously saved model (optional; requires `import os`)
# if os.path.exists('./model.ckpt'):
#     agent.restore('./model.ckpt')
#     evaluate(env, agent, render=True)
#     exit()

for i in range(1000):
    obs_list, action_list, reward_list = run_episode(env, agent)
    if i % 10 == 0:
        logger.info("Episode {}, Reward Sum {}.".format(
            i, sum(reward_list)))

    batch_obs = np.array(obs_list)
    batch_action = np.array(action_list)
    batch_reward = calc_reward_to_go(reward_list)

    agent.learn(batch_obs, batch_action, batch_reward)
    if (i + 1) % 100 == 0:
        total_reward = evaluate(env, agent, render=False)  # render=True shows the environment; requires running locally, AI Studio cannot display it
        logger.info('Test reward: {}'.format(total_reward))

# Save the model to ./model.ckpt
agent.save('./model.ckpt')
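
Later the saved parameters can be reloaded and the trained policy watched locally. A minimal sketch, assuming the same Model / PolicyGradient / Agent classes defined above and a local display for rendering (PARL's Agent.restore loads what Agent.save wrote):

# Sketch: restore the trained policy and evaluate it greedily (render needs a local display)
model = Model(obs_dim, act_dim)
alg = PolicyGradient(model, lr=LEARNING_RATE)
agent = Agent(alg, obs_dim=obs_dim, act_dim=act_dim)
agent.restore('./model.ckpt')
print('Average reward over 5 greedy episodes:', evaluate(env, agent, render=True))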

Output: the policy gradient algorithm converges remarkably fast; after only a few dozen seconds of training it has essentially converged. The agent's training log is shown below.

[12-02 10:05:40 MainThread @3974053785.py:128] obs_dim 4, act_dim 2
[12-02 10:05:40 MainThread @machine_info.py:88] nvidia-smi -L found gpu count: 1
[12-02 10:05:40 MainThread @3974053785.py:145] Episode 0, Reward Sum 34.0.
[12-02 10:05:40 MainThread @3974053785.py:145] Episode 10, Reward Sum 12.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 20, Reward Sum 11.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 30, Reward Sum 18.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 40, Reward Sum 20.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 50, Reward Sum 31.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 60, Reward Sum 40.0.
[12-02 10:05:41 MainThread @3974053785.py:145] Episode 70, Reward Sum 16.0.
[12-02 10:05:42 MainThread @3974053785.py:145] Episode 80, Reward Sum 16.0.
[12-02 10:05:42 MainThread @3974053785.py:145] Episode 90, Reward Sum 22.0.
[12-02 10:05:42 MainThread @3974053785.py:154] Test reward: 45.2
[12-02 10:05:42 MainThread @3974053785.py:145] Episode 100, Reward Sum 13.0.
[12-02 10:05:42 MainThread @3974053785.py:145] Episode 110, Reward Sum 14.0.
[12-02 10:05:43 MainThread @3974053785.py:145] Episode 120, Reward Sum 68.0.
[12-02 10:05:43 MainThread @3974053785.py:145] Episode 130, Reward Sum 28.0.
[12-02 10:05:43 MainThread @3974053785.py:145] Episode 140, Reward Sum 25.0.
[12-02 10:05:43 MainThread @3974053785.py:145] Episode 150, Reward Sum 55.0.
[12-02 10:05:44 MainThread @3974053785.py:145] Episode 160, Reward Sum 87.0.
[12-02 10:05:44 MainThread @3974053785.py:145] Episode 170, Reward Sum 35.0.
[12-02 10:05:44 MainThread @3974053785.py:145] Episode 180, Reward Sum 59.0.
[12-02 10:05:44 MainThread @3974053785.py:145] Episode 190, Reward Sum 40.0.
[12-02 10:05:45 MainThread @3974053785.py:154] Test reward: 81.0
[12-02 10:05:45 MainThread @3974053785.py:145] Episode 200, Reward Sum 63.0.
[12-02 10:05:45 MainThread @3974053785.py:145] Episode 210, Reward Sum 22.0.
[12-02 10:05:45 MainThread @3974053785.py:145] Episode 220, Reward Sum 86.0.
[12-02 10:05:46 MainThread @3974053785.py:145] Episode 230, Reward Sum 65.0.
[12-02 10:05:46 MainThread @3974053785.py:145] Episode 240, Reward Sum 24.0.
[12-02 10:05:46 MainThread @3974053785.py:145] Episode 250, Reward Sum 26.0.
[12-02 10:05:46 MainThread @3974053785.py:145] Episode 260, Reward Sum 34.0.
[12-02 10:05:47 MainThread @3974053785.py:145] Episode 270, Reward Sum 70.0.
[12-02 10:05:47 MainThread @3974053785.py:145] Episode 280, Reward Sum 37.0.
[12-02 10:05:47 MainThread @3974053785.py:145] Episode 290, Reward Sum 43.0.
[12-02 10:05:48 MainThread @3974053785.py:154] Test reward: 98.0
[12-02 10:05:48 MainThread @3974053785.py:145] Episode 300, Reward Sum 33.0.
[12-02 10:05:48 MainThread @3974053785.py:145] Episode 310, Reward Sum 49.0.
[12-02 10:05:49 MainThread @3974053785.py:145] Episode 320, Reward Sum 66.0.
[12-02 10:05:49 MainThread @3974053785.py:145] Episode 330, Reward Sum 54.0.
[12-02 10:05:49 MainThread @3974053785.py:145] Episode 340, Reward Sum 98.0.
[12-02 10:05:50 MainThread @3974053785.py:145] Episode 350, Reward Sum 81.0.
[12-02 10:05:50 MainThread @3974053785.py:145] Episode 360, Reward Sum 78.0.
[12-02 10:05:50 MainThread @3974053785.py:145] Episode 370, Reward Sum 112.0.
[12-02 10:05:51 MainThread @3974053785.py:145] Episode 380, Reward Sum 40.0.
[12-02 10:05:51 MainThread @3974053785.py:145] Episode 390, Reward Sum 37.0.
[12-02 10:05:52 MainThread @3974053785.py:154] Test reward: 104.4
[12-02 10:05:52 MainThread @3974053785.py:145] Episode 400, Reward Sum 74.0.
[12-02 10:05:52 MainThread @3974053785.py:145] Episode 410, Reward Sum 44.0.
[12-02 10:05:52 MainThread @3974053785.py:145] Episode 420, Reward Sum 39.0.
[12-02 10:05:53 MainThread @3974053785.py:145] Episode 430, Reward Sum 104.0.
[12-02 10:05:53 MainThread @3974053785.py:145] Episode 440, Reward Sum 22.0.
[12-02 10:05:54 MainThread @3974053785.py:145] Episode 450, Reward Sum 92.0.
[12-02 10:05:54 MainThread @3974053785.py:145] Episode 460, Reward Sum 16.0.
[12-02 10:05:54 MainThread @3974053785.py:145] Episode 470, Reward Sum 16.0.
[12-02 10:05:55 MainThread @3974053785.py:145] Episode 480, Reward Sum 117.0.
[12-02 10:05:55 MainThread @3974053785.py:145] Episode 490, Reward Sum 72.0.
[12-02 10:05:56 MainThread @3974053785.py:154] Test reward: 198.8
[12-02 10:05:56 MainThread @3974053785.py:145] Episode 500, Reward Sum 20.0.
[12-02 10:05:57 MainThread @3974053785.py:145] Episode 510, Reward Sum 62.0.
[12-02 10:05:57 MainThread @3974053785.py:145] Episode 520, Reward Sum 17.0.
[12-02 10:05:58 MainThread @3974053785.py:145] Episode 530, Reward Sum 56.0.
[12-02 10:05:59 MainThread @3974053785.py:145] Episode 540, Reward Sum 101.0.
[12-02 10:05:59 MainThread @3974053785.py:145] Episode 550, Reward Sum 178.0.
[12-02 10:06:00 MainThread @3974053785.py:145] Episode 560, Reward Sum 57.0.
[12-02 10:06:01 MainThread @3974053785.py:145] Episode 570, Reward Sum 158.0.
[12-02 10:06:01 MainThread @3974053785.py:145] Episode 580, Reward Sum 72.0.
[12-02 10:06:02 MainThread @3974053785.py:145] Episode 590, Reward Sum 161.0.
[12-02 10:06:03 MainThread @3974053785.py:154] Test reward: 182.2
[12-02 10:06:03 MainThread @3974053785.py:145] Episode 600, Reward Sum 113.0.
[12-02 10:06:04 MainThread @3974053785.py:145] Episode 610, Reward Sum 112.0.
[12-02 10:06:04 MainThread @3974053785.py:145] Episode 620, Reward Sum 61.0.
[12-02 10:06:05 MainThread @3974053785.py:145] Episode 630, Reward Sum 143.0.
[12-02 10:06:06 MainThread @3974053785.py:145] Episode 640, Reward Sum 156.0.
[12-02 10:06:07 MainThread @3974053785.py:145] Episode 650, Reward Sum 150.0.
[12-02 10:06:08 MainThread @3974053785.py:145] Episode 660, Reward Sum 167.0.
[12-02 10:06:09 MainThread @3974053785.py:145] Episode 670, Reward Sum 200.0.
[12-02 10:06:10 MainThread @3974053785.py:145] Episode 680, Reward Sum 200.0.
[12-02 10:06:10 MainThread @3974053785.py:145] Episode 690, Reward Sum 164.0.
[12-02 10:06:12 MainThread @3974053785.py:154] Test reward: 199.8
[12-02 10:06:12 MainThread @3974053785.py:145] Episode 700, Reward Sum 126.0.
[12-02 10:06:13 MainThread @3974053785.py:145] Episode 710, Reward Sum 164.0.
[12-02 10:06:13 MainThread @3974053785.py:145] Episode 720, Reward Sum 200.0.
[12-02 10:06:14 MainThread @3974053785.py:145] Episode 730, Reward Sum 92.0.
[12-02 10:06:15 MainThread @3974053785.py:145] Episode 740, Reward Sum 200.0.
[12-02 10:06:16 MainThread @3974053785.py:145] Episode 750, Reward Sum 197.0.
[12-02 10:06:17 MainThread @3974053785.py:145] Episode 760, Reward Sum 200.0.
[12-02 10:06:18 MainThread @3974053785.py:145] Episode 770, Reward Sum 178.0.
[12-02 10:06:19 MainThread @3974053785.py:145] Episode 780, Reward Sum 200.0.
[12-02 10:06:20 MainThread @3974053785.py:145] Episode 790, Reward Sum 200.0.
[12-02 10:06:21 MainThread @3974053785.py:154] Test reward: 200.0
[12-02 10:06:21 MainThread @3974053785.py:145] Episode 800, Reward Sum 144.0.
[12-02 10:06:22 MainThread @3974053785.py:145] Episode 810, Reward Sum 195.0.
[12-02 10:06:24 MainThread @3974053785.py:145] Episode 820, Reward Sum 174.0.
[12-02 10:06:25 MainThread @3974053785.py:145] Episode 830, Reward Sum 167.0.
[12-02 10:06:26 MainThread @3974053785.py:145] Episode 840, Reward Sum 125.0.
[12-02 10:06:27 MainThread @3974053785.py:145] Episode 850, Reward Sum 62.0.
[12-02 10:06:27 MainThread @3974053785.py:145] Episode 860, Reward Sum 200.0.
[12-02 10:06:29 MainThread @3974053785.py:145] Episode 870, Reward Sum 137.0.
[12-02 10:06:30 MainThread @3974053785.py:145] Episode 880, Reward Sum 200.0.
[12-02 10:06:31 MainThread @3974053785.py:145] Episode 890, Reward Sum 30.0.
[12-02 10:06:32 MainThread @3974053785.py:154] Test reward: 200.0
[12-02 10:06:32 MainThread @3974053785.py:145] Episode 900, Reward Sum 161.0.
[12-02 10:06:33 MainThread @3974053785.py:145] Episode 910, Reward Sum 200.0.
[12-02 10:06:34 MainThread @3974053785.py:145] Episode 920, Reward Sum 194.0.
[12-02 10:06:36 MainThread @3974053785.py:145] Episode 930, Reward Sum 200.0.
[12-02 10:06:37 MainThread @3974053785.py:145] Episode 940, Reward Sum 200.0.
[12-02 10:06:38 MainThread @3974053785.py:145] Episode 950, Reward Sum 200.0.
[12-02 10:06:39 MainThread @3974053785.py:145] Episode 960, Reward Sum 200.0.
[12-02 10:06:40 MainThread @3974053785.py:145] Episode 970, Reward Sum 200.0.
[12-02 10:06:41 MainThread @3974053785.py:145] Episode 980, Reward Sum 200.0.
[12-02 10:06:42 MainThread @3974053785.py:145] Episode 990, Reward Sum 193.0.
[12-02 10:06:44 MainThread @3974053785.py:154] Test reward: 200.0
