a2c
This commit is contained in:
parent fd621971e5
commit 6e563fe61a

160 test/test_a2c.py Normal file
@@ -0,0 +1,160 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import A2CPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer


class Net(nn.Module):
    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
        super().__init__()
        self.device = device
        self.model = [
            nn.Linear(np.prod(state_shape), 128),
            nn.ReLU(inplace=True)]
        for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
        self.critic = self.model + [nn.Linear(128, 1)]
        self.actor = nn.Sequential(*self.actor)
        self.critic = nn.Sequential(*self.critic)

    def forward(self, s, **kwargs):
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
        logits = self.actor(s)
        value = self.critic(s)
        return logits, value, None


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=320)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=2)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    # a2c special
    parser.add_argument('--vf-coef', type=float, default=0.5)
    parser.add_argument('--entropy-coef', type=float, default=0.001)
    args = parser.parse_known_args()[0]
    return args


def test_a2c(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = A2CPolicy(
        net, optim, dist, args.gamma,
        vf_coef=args.vf_coef,
        entropy_coef=args.entropy_coef)
    # collector
    training_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(
        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    for epoch in range(1, 1 + args.epoch):
        desc = f'Epoch #{epoch}'
        # train
        policy.train()
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_episode=args.collect_per_step)
                losses = policy.learn(
                    training_collector.sample(0), args.batch_size)
                training_collector.reset_buffer()
                global_step += len(losses)
                t.update(len(losses))
                stat_loss.add(losses)
                writer.add_scalar(
                    'reward', result['reward'], global_step=global_step)
                writer.add_scalar(
                    'length', result['length'], global_step=global_step)
                writer.add_scalar(
                    'loss', stat_loss.get(), global_step=global_step)
                writer.add_scalar(
                    'speed', result['speed'], global_step=global_step)
                t.set_postfix(loss=f'{stat_loss.get():.6f}',
                              reward=f'{result["reward"]:.6f}',
                              length=f'{result["length"]:.2f}',
                              speed=f'{result["speed"]:.2f}')
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if best_reward >= env.spec.reward_threshold:
            break
    assert best_reward >= env.spec.reward_threshold
    training_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env)
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()


if __name__ == '__main__':
    test_a2c()
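A note on the Net helper above: actor and critic both extend the same self.model list, so the two nn.Sequential heads contain the very same trunk modules and share those 128-unit weights; only the final linear layers differ. A standalone sketch (not part of the commit; CartPole-v0 sizes assumed) that checks the resulting shapes and the sharing:

import numpy as np
import torch
from torch import nn

state_shape, action_shape, layer_num = (4,), (2,), 2
model = [nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True)]
for _ in range(layer_num):
    model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
# Both heads are built from the same list, so the trunk layers are shared.
actor = nn.Sequential(*(model + [nn.Linear(128, np.prod(action_shape))]))
critic = nn.Sequential(*(model + [nn.Linear(128, 1)]))

obs = torch.zeros(8, 4)        # a batch of 8 flattened observations
print(actor(obs).shape)        # torch.Size([8, 2]) -> action logits
print(critic(obs).shape)       # torch.Size([8, 1]) -> state values
print(actor[0] is critic[0])   # True: first trunk layer is the same object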
test/test_pg.py
@@ -26,8 +26,7 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
 
 
 def test_fn(size=2560):
-    policy = PGPolicy(
-        None, None, None, discount_factor=0.1, normalized_reward=False)
+    policy = PGPolicy(None, None, None, discount_factor=0.1)
     fn = policy.process_fn
     # fn = compute_return_base
     batch = Batch(
@@ -36,7 +35,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -44,7 +42,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,7 +49,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     if __name__ == '__main__':
         batch = Batch(
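The dropped `ans -= ans.mean()` lines above follow from the PGPolicy change later in this commit: process_fn now returns the raw discounted sums and the mean/std normalization happens inside learn(), so the test compares against un-normalized returns. A toy check of the return recursion with gamma = 0.1 (my own numbers, not the test's):

import numpy as np

gamma = 0.1
rew  = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=float)
done = np.array([0, 0, 0, 1, 0, 0, 0, 1], dtype=float)

# Same backward pass as PGPolicy._vanilla_returns: accumulate the discounted
# tail, and carry nothing across an episode boundary (done == 1).
returns, last = rew.copy(), 0.0
for i in range(len(rew) - 1, -1, -1):
    if not done[i]:
        returns[i] += gamma * last
    last = returns[i]
print(returns)   # [1.234, 2.34, 3.4, 4.0, 5.678, 6.78, 7.8, 8.0]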
46 tianshou/env/wrapper.py vendored
@@ -143,27 +143,31 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
     done = False
-    while True:
-        cmd, data = p.recv()
-        if cmd == 'step':
-            if reset_after_done or not done:
-                obs, rew, done, info = env.step(data)
-            if reset_after_done and done:
-                # s_ is useless when episode finishes
-                obs = env.reset()
-            p.send([obs, rew, done, info])
-        elif cmd == 'reset':
-            done = False
-            p.send(env.reset())
-        elif cmd == 'close':
-            p.close()
-            break
-        elif cmd == 'render':
-            p.send(env.render() if hasattr(env, 'render') else None)
-        elif cmd == 'seed':
-            p.send(env.seed(data) if hasattr(env, 'seed') else None)
-        else:
-            raise NotImplementedError
+    try:
+        while True:
+            cmd, data = p.recv()
+            if cmd == 'step':
+                if reset_after_done or not done:
+                    obs, rew, done, info = env.step(data)
+                if reset_after_done and done:
+                    # s_ is useless when episode finishes
+                    obs = env.reset()
+                p.send([obs, rew, done, info])
+            elif cmd == 'reset':
+                done = False
+                p.send(env.reset())
+            elif cmd == 'close':
+                p.close()
+                break
+            elif cmd == 'render':
+                p.send(env.render() if hasattr(env, 'render') else None)
+            elif cmd == 'seed':
+                p.send(env.seed(data) if hasattr(env, 'seed') else None)
+            else:
+                p.close()
+                raise NotImplementedError
+    except KeyboardInterrupt:
+        p.close()
 
 
 class SubprocVectorEnv(BaseVectorEnv):
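The change to worker() above keeps the command protocol intact and adds two safety behaviors: the pipe is closed before raising on an unknown command, and the whole loop is wrapped in try/except KeyboardInterrupt so a Ctrl-C in the main process doesn't leave child processes holding an open pipe. A minimal standalone sketch of that pattern (plain multiprocessing, not tianshou code):

import multiprocessing as mp


def worker(pipe):
    # Serve commands until 'close'; make sure the child end of the pipe is
    # closed both on unknown commands and on KeyboardInterrupt.
    try:
        while True:
            cmd, data = pipe.recv()
            if cmd == 'echo':
                pipe.send(data)
            elif cmd == 'close':
                pipe.close()
                break
            else:
                pipe.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        pipe.close()


if __name__ == '__main__':
    parent, child = mp.Pipe()
    proc = mp.Process(target=worker, args=(child,))
    proc.start()
    parent.send(('echo', 42))
    print(parent.recv())        # 42
    parent.send(('close', None))
    proc.join()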
tianshou/policy/__init__.py
@@ -1,9 +1,11 @@
 from tianshou.policy.base import BasePolicy
 from tianshou.policy.dqn import DQNPolicy
 from tianshou.policy.policy_gradient import PGPolicy
+from tianshou.policy.a2c import A2CPolicy
 
 __all__ = [
     'BasePolicy',
     'DQNPolicy',
     'PGPolicy',
+    'A2CPolicy',
 ]
42 tianshou/policy/a2c.py Normal file
@@ -0,0 +1,42 @@
import torch
import torch.nn.functional as F

from tianshou.data import Batch
from tianshou.policy import PGPolicy


class A2CPolicy(PGPolicy):
    """docstring for A2CPolicy"""

    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
        super().__init__(model, optim, dist_fn, discount_factor)
        self._w_value = vf_coef
        self._w_entropy = entropy_coef

    def __call__(self, batch, state=None):
        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
        logits = F.softmax(logits, dim=1)
        dist = self.dist_fn(logits)
        act = dist.sample().detach().cpu().numpy()
        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)

    def learn(self, batch, batch_size=None):
        losses = []
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            result = self(b)
            dist = result.dist
            v = result.value
            a = torch.tensor(b.act, device=dist.logits.device)
            r = torch.tensor(b.returns, device=dist.logits.device)
            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
            critic_loss = (r - v).pow(2).mean()
            entropy_loss = dist.entropy().mean()
            loss = actor_loss \
                + self._w_value * critic_loss \
                - self._w_entropy * entropy_loss
            loss.backward()
            self.optim.step()
            losses.append(loss.detach().cpu().numpy())
        return losses
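learn() above optimizes the usual A2C objective: a policy-gradient term weighted by the detached advantage (r - v), a value-regression term scaled by vf_coef, minus an entropy bonus scaled by entropy_coef. One shape detail worth noting: r is built from batch.returns as a flat vector while v keeps the critic's trailing unit dimension, so (r - v) broadcasts; the toy sketch below (my own numbers, not from the commit) keeps both 1-D to show the intended per-sample arithmetic:

import torch

vf_coef, entropy_coef = 0.5, 0.01

probs   = torch.tensor([[0.7, 0.3], [0.4, 0.6]], requires_grad=True)  # softmax output
value   = torch.tensor([0.5, 1.0], requires_grad=True)                # critic output, flattened
returns = torch.tensor([1.0, 0.2])                                    # discounted returns
act     = torch.tensor([0, 1])                                        # sampled actions

dist = torch.distributions.Categorical(probs)
adv = (returns - value).detach()     # advantage; detached so it only scales the actor term
actor_loss = -(dist.log_prob(act) * adv).mean()
critic_loss = (returns - value).pow(2).mean()
entropy_loss = dist.entropy().mean()
loss = actor_loss + vf_coef * critic_loss - entropy_coef * entropy_loss
loss.backward()                      # gradients reach both probs and value
print(loss.item())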
tianshou/policy/policy_gradient.py
@@ -11,7 +11,7 @@ class PGPolicy(BasePolicy, nn.Module):
     """docstring for PGPolicy"""
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, normalized_reward=True):
+                 discount_factor=0.99):
         super().__init__()
         self.model = model
         self.optim = optim
@@ -19,15 +19,10 @@ class PGPolicy(BasePolicy, nn.Module):
         self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
-        self._rew_norm = normalized_reward
 
     def process_fn(self, batch, buffer, indice):
-        batch_size = len(batch.rew)
-        returns = self._vanilla_returns(batch, batch_size)
-        # returns = self._vectorized_returns(batch, batch_size)
-        returns = returns - returns.mean()
-        if self._rew_norm:
-            returns = returns / (returns.std() + self._eps)
+        returns = self._vanilla_returns(batch)
+        # returns = self._vectorized_returns(batch)
         batch.update(returns=returns)
         return batch
 
@@ -40,6 +35,8 @@ class PGPolicy(BasePolicy, nn.Module):
 
     def learn(self, batch, batch_size=None):
         losses = []
+        batch.returns = (batch.returns - batch.returns.mean()) \
+            / (batch.returns.std() + self._eps)
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
@@ -51,21 +48,22 @@ class PGPolicy(BasePolicy, nn.Module):
             losses.append(loss.detach().cpu().numpy())
         return losses
 
-    def _vanilla_returns(self, batch, batch_size):
+    def _vanilla_returns(self, batch):
         returns = batch.rew[:]
         last = 0
-        for i in range(batch_size - 1, -1, -1):
+        for i in range(len(returns) - 1, -1, -1):
             if not batch.done[i]:
                 returns[i] += self._gamma * last
             last = returns[i]
         return returns
 
-    def _vectorized_returns(self, batch, batch_size):
+    def _vectorized_returns(self, batch):
         # according to my tests, it is slower than vanilla
         # import scipy.signal
         convolve = np.convolve
         # convolve = scipy.signal.convolve
         rew = batch.rew[::-1]
+        batch_size = len(rew)
         gammas = self._gamma ** np.arange(batch_size)
         c = convolve(rew, gammas)[:batch_size]
         T = np.where(batch.done[::-1])[0]
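The hunk above is cut off partway through _vectorized_returns, but the core trick is visible: reverse the rewards, convolve with the vector of gamma powers, and the first len(rew) entries of the convolution are the discounted tail sums in reverse time order (presumably the done indices collected in T are then used to cut the sums at episode boundaries, which lies outside this hunk). For a single episode the idea reduces to the following sketch (my own toy numbers):

import numpy as np

gamma = 0.9
rew = np.array([1., 2., 3.])      # one episode, no intermediate done flags

rev = rew[::-1]                   # [3, 2, 1]
gammas = gamma ** np.arange(len(rev))
# Truncated convolution: entry k sums rev[k-j] * gamma**j, i.e. the discounted
# return of the k-th step counted from the episode's end.
returns = np.convolve(rev, gammas)[:len(rev)][::-1]
print(returns)   # [5.23, 4.7, 3.0]  ==  [1 + 0.9*2 + 0.81*3, 2 + 0.9*3, 3]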