Commit 6e563fe61a (parent fd621971e5) by Trinkle23897, 2020-03-17 20:22:37 +08:00
6 changed files with 239 additions and 37 deletions

test/test_a2c.py (new file, 160 lines)

@@ -0,0 +1,160 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import A2CPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer
class Net(nn.Module):
def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
super().__init__()
self.device = device
self.model = [
nn.Linear(np.prod(state_shape), 128),
nn.ReLU(inplace=True)]
for i in range(layer_num):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
self.critic = self.model + [nn.Linear(128, 1)]
self.actor = nn.Sequential(*self.actor)
self.critic = nn.Sequential(*self.critic)
def forward(self, s, **kwargs):
s = torch.tensor(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.actor(s)
value = self.critic(s)
return logits, value, None
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=320)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=2)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
# a2c special
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--entropy-coef', type=float, default=0.001)
args = parser.parse_known_args()[0]
return args
def test_a2c(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
net, optim, dist, args.gamma,
vf_coef=args.vf_coef,
entropy_coef=args.entropy_coef)
# collector
training_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(
policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
# log
stat_loss = MovAvg()
global_step = 0
writer = SummaryWriter(args.logdir)
best_epoch = -1
best_reward = -1e10
start_time = time.time()
for epoch in range(1, 1 + args.epoch):
desc = f'Epoch #{epoch}'
# train
policy.train()
with tqdm.tqdm(
total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
while t.n < t.total:
result = training_collector.collect(
n_episode=args.collect_per_step)
losses = policy.learn(
training_collector.sample(0), args.batch_size)
training_collector.reset_buffer()
global_step += len(losses)
t.update(len(losses))
stat_loss.add(losses)
writer.add_scalar(
'reward', result['reward'], global_step=global_step)
writer.add_scalar(
'length', result['length'], global_step=global_step)
writer.add_scalar(
'loss', stat_loss.get(), global_step=global_step)
writer.add_scalar(
'speed', result['speed'], global_step=global_step)
t.set_postfix(loss=f'{stat_loss.get():.6f}',
reward=f'{result["reward"]:.6f}',
length=f'{result["length"]:.2f}',
speed=f'{result["speed"]:.2f}')
# eval
test_collector.reset_env()
test_collector.reset_buffer()
policy.eval()
result = test_collector.collect(n_episode=args.test_num)
if best_reward < result['reward']:
best_reward = result['reward']
best_epoch = epoch
print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')
if best_reward >= env.spec.reward_threshold:
break
assert best_reward >= env.spec.reward_threshold
training_collector.close()
test_collector.close()
if __name__ == '__main__':
train_cnt = training_collector.collect_step
test_cnt = test_collector.collect_step
duration = time.time() - start_time
        print(f'Collected {train_cnt} training frames and {test_cnt} test frames '
f'in {duration:.2f}s, '
f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
# Let's watch its performance!
env = gym.make(args.task)
test_collector = Collector(policy, env)
result = test_collector.collect(n_episode=1, render=1 / 35)
print(f'Final reward: {result["reward"]}, length: {result["length"]}')
test_collector.close()
if __name__ == '__main__':
test_a2c()
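
As a quick shape check for the two-headed Net above, here is a minimal sketch; the 4-dimensional observation and 2 discrete actions are CartPole-v0's spaces, assumed here rather than queried from gym, and the Net class from this file must be in scope:

import numpy as np

net = Net(layer_num=2, state_shape=(4,), action_shape=2)
logits, value, state = net(np.zeros((3, 4), dtype=np.float32))
assert logits.shape == (3, 2)  # one logit per discrete action
assert value.shape == (3, 1)   # scalar state value per sample
assert state is None           # the recurrent-state slot is unused here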

test/test_pg.py

@@ -26,8 +26,7 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
def test_fn(size=2560):
-    policy = PGPolicy(
-        None, None, None, discount_factor=0.1, normalized_reward=False)
+    policy = PGPolicy(None, None, None, discount_factor=0.1)
fn = policy.process_fn
# fn = compute_return_base
batch = Batch(
@@ -36,7 +35,6 @@ def test_fn(size=2560):
)
batch = fn(batch, None, None)
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
-    ans -= ans.mean()
assert abs(batch.returns - ans).sum() <= 1e-5
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -44,7 +42,6 @@ def test_fn(size=2560):
)
batch = fn(batch, None, None)
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
-    ans -= ans.mean()
assert abs(batch.returns - ans).sum() <= 1e-5
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,7 +49,6 @@ def test_fn(size=2560):
)
batch = fn(batch, None, None)
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
-    ans -= ans.mean()
assert abs(batch.returns - ans).sum() <= 1e-5
if __name__ == '__main__':
batch = Batch(


@@ -143,27 +143,31 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
parent.close()
env = env_fn_wrapper.data()
done = False
-    while True:
-        cmd, data = p.recv()
-        if cmd == 'step':
-            if reset_after_done or not done:
-                obs, rew, done, info = env.step(data)
-            if reset_after_done and done:
-                # s_ is useless when episode finishes
-                obs = env.reset()
-            p.send([obs, rew, done, info])
-        elif cmd == 'reset':
-            done = False
-            p.send(env.reset())
-        elif cmd == 'close':
-            p.close()
-            break
-        elif cmd == 'render':
-            p.send(env.render() if hasattr(env, 'render') else None)
-        elif cmd == 'seed':
-            p.send(env.seed(data) if hasattr(env, 'seed') else None)
-        else:
-            raise NotImplementedError
+    try:
+        while True:
+            cmd, data = p.recv()
+            if cmd == 'step':
+                if reset_after_done or not done:
+                    obs, rew, done, info = env.step(data)
+                if reset_after_done and done:
+                    # s_ is useless when episode finishes
+                    obs = env.reset()
+                p.send([obs, rew, done, info])
+            elif cmd == 'reset':
+                done = False
+                p.send(env.reset())
+            elif cmd == 'close':
+                p.close()
+                break
+            elif cmd == 'render':
+                p.send(env.render() if hasattr(env, 'render') else None)
+            elif cmd == 'seed':
+                p.send(env.seed(data) if hasattr(env, 'seed') else None)
+            else:
+                p.close()
+                raise NotImplementedError
+    except KeyboardInterrupt:
+        p.close()
class SubprocVectorEnv(BaseVectorEnv):
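
The worker above services a small command protocol over a pipe. As a usage sketch from the parent side (hedged: it assumes SubprocVectorEnv's seed/reset/step/close methods map one-to-one onto the 'seed'/'reset'/'step'/'close' commands; the constructor signature is taken from test_a2c.py in this commit):

import gym
import numpy as np
from tianshou.env import SubprocVectorEnv

envs = SubprocVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(4)],
    reset_after_done=True)
envs.seed(0)                 # 'seed' command sent to every worker
obs = envs.reset()           # 'reset' command; stacked observations come back
obs, rew, done, info = envs.step(np.zeros(4, dtype=int))  # 'step' command
envs.close()                 # 'close' command terminates the workers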

tianshou/policy/__init__.py

@@ -1,9 +1,11 @@
from tianshou.policy.base import BasePolicy
from tianshou.policy.dqn import DQNPolicy
from tianshou.policy.policy_gradient import PGPolicy
+from tianshou.policy.a2c import A2CPolicy
__all__ = [
'BasePolicy',
'DQNPolicy',
'PGPolicy',
+    'A2CPolicy',
]
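
With the export above, the new policy is importable from the package root, exactly as test_a2c.py does:

from tianshou.policy import A2CPolicy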

tianshou/policy/a2c.py (new file, 42 lines)

@@ -0,0 +1,42 @@
import torch
import torch.nn.functional as F
from tianshou.data import Batch
from tianshou.policy import PGPolicy
class A2CPolicy(PGPolicy):
"""docstring for A2CPolicy"""
def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
super().__init__(model, optim, dist_fn, discount_factor)
self._w_value = vf_coef
self._w_entropy = entropy_coef
def __call__(self, batch, state=None):
logits, value, h = self.model(batch.obs, state=state, info=batch.info)
logits = F.softmax(logits, dim=1)
dist = self.dist_fn(logits)
act = dist.sample().detach().cpu().numpy()
return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
def learn(self, batch, batch_size=None):
losses = []
for b in batch.split(batch_size):
self.optim.zero_grad()
result = self(b)
dist = result.dist
            v = result.value.flatten()  # critic output may be (B, 1); flatten so (r - v) stays 1-D
a = torch.tensor(b.act, device=dist.logits.device)
r = torch.tensor(b.returns, device=dist.logits.device)
actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
critic_loss = (r - v).pow(2).mean()
entropy_loss = dist.entropy().mean()
loss = actor_loss \
+ self._w_value * critic_loss \
- self._w_entropy * entropy_loss
loss.backward()
self.optim.step()
losses.append(loss.detach().cpu().numpy())
return losses
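
For reference, the loss assembled in learn above corresponds to the per-minibatch objective (with the advantage detached in the actor term)

\mathcal{L} = -\frac{1}{B}\sum_t \log \pi_\theta(a_t \mid s_t)\,\big(R_t - V_\theta(s_t)\big)
  + c_v\,\frac{1}{B}\sum_t \big(R_t - V_\theta(s_t)\big)^2
  - c_e\,\frac{1}{B}\sum_t \mathcal{H}\big(\pi_\theta(\cdot \mid s_t)\big),

where R_t is the discounted return produced by process_fn, c_v is vf_coef and c_e is entropy_coef.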

tianshou/policy/policy_gradient.py

@@ -11,7 +11,7 @@ class PGPolicy(BasePolicy, nn.Module):
"""docstring for PGPolicy"""
def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, normalized_reward=True):
+                 discount_factor=0.99):
super().__init__()
self.model = model
self.optim = optim
@@ -19,15 +19,10 @@ class PGPolicy(BasePolicy, nn.Module):
self._eps = np.finfo(np.float32).eps.item()
assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
self._gamma = discount_factor
-        self._rew_norm = normalized_reward
def process_fn(self, batch, buffer, indice):
-        batch_size = len(batch.rew)
-        returns = self._vanilla_returns(batch, batch_size)
-        # returns = self._vectorized_returns(batch, batch_size)
-        returns = returns - returns.mean()
-        if self._rew_norm:
-            returns = returns / (returns.std() + self._eps)
+        returns = self._vanilla_returns(batch)
+        # returns = self._vectorized_returns(batch)
batch.update(returns=returns)
return batch
@@ -40,6 +35,8 @@ class PGPolicy(BasePolicy, nn.Module):
def learn(self, batch, batch_size=None):
losses = []
+        batch.returns = (batch.returns - batch.returns.mean()) \
+            / (batch.returns.std() + self._eps)
for b in batch.split(batch_size):
self.optim.zero_grad()
dist = self(b).dist
@@ -51,21 +48,22 @@ class PGPolicy(BasePolicy, nn.Module):
losses.append(loss.detach().cpu().numpy())
return losses
-    def _vanilla_returns(self, batch, batch_size):
+    def _vanilla_returns(self, batch):
returns = batch.rew[:]
last = 0
-        for i in range(batch_size - 1, -1, -1):
+        for i in range(len(returns) - 1, -1, -1):
if not batch.done[i]:
returns[i] += self._gamma * last
last = returns[i]
return returns
-    def _vectorized_returns(self, batch, batch_size):
+    def _vectorized_returns(self, batch):
# according to my tests, it is slower than vanilla
# import scipy.signal
convolve = np.convolve
# convolve = scipy.signal.convolve
rew = batch.rew[::-1]
+        batch_size = len(rew)
gammas = self._gamma ** np.arange(batch_size)
c = convolve(rew, gammas)[:batch_size]
T = np.where(batch.done[::-1])[0]
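
Both return helpers compute the same target: the per-step discounted return

R_t = r_t + \gamma\,(1 - d_t)\,R_{t+1},

where d_t is the done flag. _vanilla_returns evaluates this recurrence in a single backward pass over the batch, while _vectorized_returns rewrites it as a convolution of the reversed rewards with the geometric weights \gamma^k and then handles episode boundaries through the done indices.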