a2c

This commit is contained in:
parent fd621971e5
commit 6e563fe61a

160 test/test_a2c.py Normal file

@@ -0,0 +1,160 @@
+import gym
+import time
+import tqdm
+import torch
+import argparse
+import numpy as np
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import A2CPolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils import tqdm_config, MovAvg
+from tianshou.data import Collector, ReplayBuffer
+
+
+class Net(nn.Module):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.model = [
+            nn.Linear(np.prod(state_shape), 128),
+            nn.ReLU(inplace=True)]
+        for i in range(layer_num):
+            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+        self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
+        self.critic = self.model + [nn.Linear(128, 1)]
+        self.actor = nn.Sequential(*self.actor)
+        self.critic = nn.Sequential(*self.critic)
+
+    def forward(self, s, **kwargs):
+        s = torch.tensor(s, device=self.device, dtype=torch.float)
+        batch = s.shape[0]
+        s = s.view(batch, -1)
+        logits = self.actor(s)
+        value = self.critic(s)
+        return logits, value, None
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='CartPole-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=320)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--layer-num', type=int, default=2)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    # a2c special
+    parser.add_argument('--vf-coef', type=float, default=0.5)
+    parser.add_argument('--entropy-coef', type=float, default=0.001)
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_a2c(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
+    train_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)],
+        reset_after_done=True)
+    # test_envs = gym.make(args.task)
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)],
+        reset_after_done=False)
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
+    net = net.to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    dist = torch.distributions.Categorical
+    policy = A2CPolicy(
+        net, optim, dist, args.gamma,
+        vf_coef=args.vf_coef,
+        entropy_coef=args.entropy_coef)
+    # collector
+    training_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size))
+    test_collector = Collector(
+        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    # log
+    stat_loss = MovAvg()
+    global_step = 0
+    writer = SummaryWriter(args.logdir)
+    best_epoch = -1
+    best_reward = -1e10
+    start_time = time.time()
+    for epoch in range(1, 1 + args.epoch):
+        desc = f'Epoch #{epoch}'
+        # train
+        policy.train()
+        with tqdm.tqdm(
+                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            while t.n < t.total:
+                result = training_collector.collect(
+                    n_episode=args.collect_per_step)
+                losses = policy.learn(
+                    training_collector.sample(0), args.batch_size)
+                training_collector.reset_buffer()
+                global_step += len(losses)
+                t.update(len(losses))
+                stat_loss.add(losses)
+                writer.add_scalar(
+                    'reward', result['reward'], global_step=global_step)
+                writer.add_scalar(
+                    'length', result['length'], global_step=global_step)
+                writer.add_scalar(
+                    'loss', stat_loss.get(), global_step=global_step)
+                writer.add_scalar(
+                    'speed', result['speed'], global_step=global_step)
+                t.set_postfix(loss=f'{stat_loss.get():.6f}',
+                              reward=f'{result["reward"]:.6f}',
+                              length=f'{result["length"]:.2f}',
+                              speed=f'{result["speed"]:.2f}')
+        # eval
+        test_collector.reset_env()
+        test_collector.reset_buffer()
+        policy.eval()
+        result = test_collector.collect(n_episode=args.test_num)
+        if best_reward < result['reward']:
+            best_reward = result['reward']
+            best_epoch = epoch
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
+              f'best_reward: {best_reward:.6f} in #{best_epoch}')
+        if best_reward >= env.spec.reward_threshold:
+            break
+    assert best_reward >= env.spec.reward_threshold
+    training_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        test_collector = Collector(policy, env)
+        result = test_collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
+        test_collector.close()
+
+
+if __name__ == '__main__':
+    test_a2c()
@@ -26,8 +26,7 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
 
 
 def test_fn(size=2560):
-    policy = PGPolicy(
-        None, None, None, discount_factor=0.1, normalized_reward=False)
+    policy = PGPolicy(None, None, None, discount_factor=0.1)
     fn = policy.process_fn
     # fn = compute_return_base
     batch = Batch(
@@ -36,7 +35,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -44,7 +42,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,7 +49,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     if __name__ == '__main__':
         batch = Batch(
46 tianshou/env/wrapper.py vendored

@@ -143,27 +143,31 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
     done = False
-    while True:
-        cmd, data = p.recv()
-        if cmd == 'step':
-            if reset_after_done or not done:
-                obs, rew, done, info = env.step(data)
-            if reset_after_done and done:
-                # s_ is useless when episode finishes
-                obs = env.reset()
-            p.send([obs, rew, done, info])
-        elif cmd == 'reset':
-            done = False
-            p.send(env.reset())
-        elif cmd == 'close':
-            p.close()
-            break
-        elif cmd == 'render':
-            p.send(env.render() if hasattr(env, 'render') else None)
-        elif cmd == 'seed':
-            p.send(env.seed(data) if hasattr(env, 'seed') else None)
-        else:
-            raise NotImplementedError
+    try:
+        while True:
+            cmd, data = p.recv()
+            if cmd == 'step':
+                if reset_after_done or not done:
+                    obs, rew, done, info = env.step(data)
+                if reset_after_done and done:
+                    # s_ is useless when episode finishes
+                    obs = env.reset()
+                p.send([obs, rew, done, info])
+            elif cmd == 'reset':
+                done = False
+                p.send(env.reset())
+            elif cmd == 'close':
+                p.close()
+                break
+            elif cmd == 'render':
+                p.send(env.render() if hasattr(env, 'render') else None)
+            elif cmd == 'seed':
+                p.send(env.seed(data) if hasattr(env, 'seed') else None)
+            else:
+                p.close()
+                raise NotImplementedError
+    except KeyboardInterrupt:
+        p.close()
 
 
 class SubprocVectorEnv(BaseVectorEnv):
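The change above wraps the worker's command loop in try/except KeyboardInterrupt (and closes the pipe in the unknown-command branch), so a Ctrl-C in the main process lets each child shut its end of the pipe down cleanly instead of dying on a broken connection. A minimal standalone sketch of the same worker-over-a-Pipe pattern; EchoEnv and run_worker are made-up names for illustration, not part of the commit:

import multiprocessing as mp


class EchoEnv:
    """Stand-in environment that just echoes the action back."""

    def step(self, action):
        return action, 0.0, False, {}

    def reset(self):
        return 0


def run_worker(parent_end, p):
    parent_end.close()  # the child only talks over its own end of the pipe
    env = EchoEnv()
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'step':
                p.send(env.step(data))
            elif cmd == 'reset':
                p.send(env.reset())
            elif cmd == 'close':
                p.close()
                break
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        # same idea as the patch: shut the pipe down cleanly on Ctrl-C
        p.close()


if __name__ == '__main__':
    parent_end, child_end = mp.Pipe()
    proc = mp.Process(target=run_worker, args=(parent_end, child_end), daemon=True)
    proc.start()
    child_end.close()  # the parent only talks over its own end
    parent_end.send(['step', 3])
    print(parent_end.recv())  # (3, 0.0, False, {})
    parent_end.send(['close', None])
    proc.join()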
@@ -1,9 +1,11 @@
 from tianshou.policy.base import BasePolicy
 from tianshou.policy.dqn import DQNPolicy
 from tianshou.policy.policy_gradient import PGPolicy
+from tianshou.policy.a2c import A2CPolicy
 
 __all__ = [
     'BasePolicy',
     'DQNPolicy',
     'PGPolicy',
+    'A2CPolicy',
 ]
42 tianshou/policy/a2c.py Normal file

@@ -0,0 +1,42 @@
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch
+from tianshou.policy import PGPolicy
+
+
+class A2CPolicy(PGPolicy):
+    """docstring for A2CPolicy"""
+
+    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
+                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
+        super().__init__(model, optim, dist_fn, discount_factor)
+        self._w_value = vf_coef
+        self._w_entropy = entropy_coef
+
+    def __call__(self, batch, state=None):
+        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
+        logits = F.softmax(logits, dim=1)
+        dist = self.dist_fn(logits)
+        act = dist.sample().detach().cpu().numpy()
+        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
+
+    def learn(self, batch, batch_size=None):
+        losses = []
+        for b in batch.split(batch_size):
+            self.optim.zero_grad()
+            result = self(b)
+            dist = result.dist
+            v = result.value
+            a = torch.tensor(b.act, device=dist.logits.device)
+            r = torch.tensor(b.returns, device=dist.logits.device)
+            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
+            critic_loss = (r - v).pow(2).mean()
+            entropy_loss = dist.entropy().mean()
+            loss = actor_loss \
+                + self._w_value * critic_loss \
+                - self._w_entropy * entropy_loss
+            loss.backward()
+            self.optim.step()
+            losses.append(loss.detach().cpu().numpy())
+        return losses
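For reference, learn above combines three terms into one scalar: a policy-gradient (actor) term weighted by the detached advantage, a squared-error critic term scaled by vf_coef, and an entropy bonus scaled by entropy_coef. A small self-contained numeric sketch of that combination; the tensor values below are dummies and the coefficients are just the defaults from the test's get_args(), not anything computed by this commit:

import torch

probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1]])      # softmax output, as in __call__
value = torch.tensor([0.5, 1.0])             # critic output V(s)
act = torch.tensor([0, 1])                   # sampled actions
returns = torch.tensor([1.0, 2.0])           # discounted returns

dist = torch.distributions.Categorical(probs)
advantage = (returns - value).detach()       # gradient stopped, as in learn()
actor_loss = -(dist.log_prob(act) * advantage).mean()
critic_loss = (returns - value).pow(2).mean()
entropy_loss = dist.entropy().mean()

vf_coef, entropy_coef = 0.5, 0.001           # defaults used by test_a2c.py
loss = actor_loss + vf_coef * critic_loss - entropy_coef * entropy_loss
print(float(loss))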
@@ -11,7 +11,7 @@ class PGPolicy(BasePolicy, nn.Module):
     """docstring for PGPolicy"""
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, normalized_reward=True):
+                 discount_factor=0.99):
         super().__init__()
         self.model = model
         self.optim = optim
@@ -19,15 +19,10 @@ class PGPolicy(BasePolicy, nn.Module):
         self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
-        self._rew_norm = normalized_reward
 
     def process_fn(self, batch, buffer, indice):
-        batch_size = len(batch.rew)
-        returns = self._vanilla_returns(batch, batch_size)
-        # returns = self._vectorized_returns(batch, batch_size)
-        returns = returns - returns.mean()
-        if self._rew_norm:
-            returns = returns / (returns.std() + self._eps)
+        returns = self._vanilla_returns(batch)
+        # returns = self._vectorized_returns(batch)
         batch.update(returns=returns)
         return batch
 
@@ -40,6 +35,8 @@ class PGPolicy(BasePolicy, nn.Module):
 
     def learn(self, batch, batch_size=None):
         losses = []
+        batch.returns = (batch.returns - batch.returns.mean()) \
+            / (batch.returns.std() + self._eps)
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
@@ -51,21 +48,22 @@ class PGPolicy(BasePolicy, nn.Module):
             losses.append(loss.detach().cpu().numpy())
         return losses
 
-    def _vanilla_returns(self, batch, batch_size):
+    def _vanilla_returns(self, batch):
         returns = batch.rew[:]
         last = 0
-        for i in range(batch_size - 1, -1, -1):
+        for i in range(len(returns) - 1, -1, -1):
             if not batch.done[i]:
                 returns[i] += self._gamma * last
             last = returns[i]
         return returns
 
-    def _vectorized_returns(self, batch, batch_size):
+    def _vectorized_returns(self, batch):
         # according to my tests, it is slower than vanilla
         # import scipy.signal
         convolve = np.convolve
         # convolve = scipy.signal.convolve
         rew = batch.rew[::-1]
+        batch_size = len(rew)
         gammas = self._gamma ** np.arange(batch_size)
         c = convolve(rew, gammas)[:batch_size]
         T = np.where(batch.done[::-1])[0]
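The hunks above drop the normalized_reward option: process_fn now returns the raw discounted returns, PGPolicy.learn normalizes them once per update, and both return helpers infer the batch size from the data instead of taking it as an argument. For clarity, a self-contained sketch of the backward recursion that _vanilla_returns implements; the rew/done arrays and gamma below are made-up example values, not the ones from the tests:

import numpy as np


def vanilla_returns(rew, done, gamma):
    """returns[i] = rew[i] + gamma * returns[i + 1], reset at episode ends."""
    returns = rew.astype(float).copy()
    last = 0.0
    for i in range(len(returns) - 1, -1, -1):  # walk the batch backwards
        if not done[i]:
            returns[i] += gamma * last         # accumulate within an episode
        last = returns[i]
    return returns


rew = np.array([1., 1., 1., 1.])
done = np.array([0, 1, 0, 1])
print(vanilla_returns(rew, done, gamma=0.1))   # [1.1 1.  1.1 1. ]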