a2c

commit 6e563fe61a
parent fd621971e5

test/test_a2c.py (new file, 160 lines)

@@ -0,0 +1,160 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import A2CPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer


class Net(nn.Module):
    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
        super().__init__()
        self.device = device
        self.model = [
            nn.Linear(np.prod(state_shape), 128),
            nn.ReLU(inplace=True)]
        for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
        self.critic = self.model + [nn.Linear(128, 1)]
        self.actor = nn.Sequential(*self.actor)
        self.critic = nn.Sequential(*self.critic)

    def forward(self, s, **kwargs):
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
        logits = self.actor(s)
        value = self.critic(s)
        return logits, value, None


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=320)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=2)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    # a2c special
    parser.add_argument('--vf-coef', type=float, default=0.5)
    parser.add_argument('--entropy-coef', type=float, default=0.001)
    args = parser.parse_known_args()[0]
    return args


def test_a2c(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = A2CPolicy(
        net, optim, dist, args.gamma,
        vf_coef=args.vf_coef,
        entropy_coef=args.entropy_coef)
    # collector
    training_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(
        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    for epoch in range(1, 1 + args.epoch):
        desc = f'Epoch #{epoch}'
        # train
        policy.train()
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_episode=args.collect_per_step)
                losses = policy.learn(
                    training_collector.sample(0), args.batch_size)
                training_collector.reset_buffer()
                global_step += len(losses)
                t.update(len(losses))
                stat_loss.add(losses)
                writer.add_scalar(
                    'reward', result['reward'], global_step=global_step)
                writer.add_scalar(
                    'length', result['length'], global_step=global_step)
                writer.add_scalar(
                    'loss', stat_loss.get(), global_step=global_step)
                writer.add_scalar(
                    'speed', result['speed'], global_step=global_step)
                t.set_postfix(loss=f'{stat_loss.get():.6f}',
                              reward=f'{result["reward"]:.6f}',
                              length=f'{result["length"]:.2f}',
                              speed=f'{result["speed"]:.2f}')
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if best_reward >= env.spec.reward_threshold:
            break
    assert best_reward >= env.spec.reward_threshold
    training_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env)
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()


if __name__ == '__main__':
    test_a2c()
test/test_pg.py

@@ -26,8 +26,7 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
 
 
 def test_fn(size=2560):
-    policy = PGPolicy(
-        None, None, None, discount_factor=0.1, normalized_reward=False)
+    policy = PGPolicy(None, None, None, discount_factor=0.1)
     fn = policy.process_fn
     # fn = compute_return_base
     batch = Batch(
@@ -36,7 +35,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -44,7 +42,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,7 +49,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     if __name__ == '__main__':
         batch = Batch(
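Note on the dropped `ans -= ans.mean()` lines: as the policy_gradient.py hunks further down show, `process_fn` no longer centers the returns; normalization moves into `PGPolicy.learn`. After this commit, `process_fn` yields the plain discounted return computed by `_vanilla_returns`, which in my own notation (not part of the commit) is

    R_t = r_t + \gamma \,(1 - \mathrm{done}_t)\, R_{t+1}

evaluated backwards over the batch with the return past the last index taken as 0; with gamma = discount_factor = 0.1 this recursion is consistent with the hard-coded `ans` arrays above.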
tianshou/env/wrapper.py (46 lines changed)

@@ -143,27 +143,31 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
     done = False
-    while True:
-        cmd, data = p.recv()
-        if cmd == 'step':
-            if reset_after_done or not done:
-                obs, rew, done, info = env.step(data)
-            if reset_after_done and done:
-                # s_ is useless when episode finishes
-                obs = env.reset()
-            p.send([obs, rew, done, info])
-        elif cmd == 'reset':
-            done = False
-            p.send(env.reset())
-        elif cmd == 'close':
-            p.close()
-            break
-        elif cmd == 'render':
-            p.send(env.render() if hasattr(env, 'render') else None)
-        elif cmd == 'seed':
-            p.send(env.seed(data) if hasattr(env, 'seed') else None)
-        else:
-            raise NotImplementedError
+    try:
+        while True:
+            cmd, data = p.recv()
+            if cmd == 'step':
+                if reset_after_done or not done:
+                    obs, rew, done, info = env.step(data)
+                if reset_after_done and done:
+                    # s_ is useless when episode finishes
+                    obs = env.reset()
+                p.send([obs, rew, done, info])
+            elif cmd == 'reset':
+                done = False
+                p.send(env.reset())
+            elif cmd == 'close':
+                p.close()
+                break
+            elif cmd == 'render':
+                p.send(env.render() if hasattr(env, 'render') else None)
+            elif cmd == 'seed':
+                p.send(env.seed(data) if hasattr(env, 'seed') else None)
+            else:
+                p.close()
+                raise NotImplementedError
+    except KeyboardInterrupt:
+        p.close()
 
 
 class SubprocVectorEnv(BaseVectorEnv):
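For context, `worker` above is the child-process side of SubprocVectorEnv: the parent keeps the other end of each Pipe and drives it with ('step', action), ('reset', None), ('seed', s), ('render', None) and ('close', None) messages. Below is a minimal standalone sketch of that pipe protocol; the DummyEnv and the driver code are illustrative assumptions, not part of this commit ('render' and 'seed' are omitted for brevity).

import multiprocessing as mp


class DummyEnv:
    """Stand-in for a gym env, only here to exercise the pipe protocol."""
    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        done = self.t >= 3
        return self.t, 1.0, done, {}


def dummy_worker(parent, p, reset_after_done=True):
    # Mirrors the command loop of worker() above: one env per process,
    # commands arrive over the pipe, results are sent back on the same pipe.
    parent.close()
    env = DummyEnv()
    done = False
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'step':
                if reset_after_done or not done:
                    obs, rew, done, info = env.step(data)
                if reset_after_done and done:
                    obs = env.reset()  # s_ is useless when episode finishes
                p.send([obs, rew, done, info])
            elif cmd == 'reset':
                done = False
                p.send(env.reset())
            elif cmd == 'close':
                p.close()
                break
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        p.close()


if __name__ == '__main__':
    parent, child = mp.Pipe()
    proc = mp.Process(target=dummy_worker, args=(parent, child), daemon=True)
    proc.start()
    child.close()  # the parent process keeps only its own end
    parent.send(['reset', None])
    print('reset ->', parent.recv())
    for a in range(4):
        parent.send(['step', a])
        print('step  ->', parent.recv())
    parent.send(['close', None])
    proc.join()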
tianshou/policy/__init__.py

@@ -1,9 +1,11 @@
 from tianshou.policy.base import BasePolicy
 from tianshou.policy.dqn import DQNPolicy
 from tianshou.policy.policy_gradient import PGPolicy
+from tianshou.policy.a2c import A2CPolicy
 
 __all__ = [
     'BasePolicy',
     'DQNPolicy',
     'PGPolicy',
+    'A2CPolicy',
 ]
tianshou/policy/a2c.py (new file, 42 lines)

@@ -0,0 +1,42 @@
import torch
import torch.nn.functional as F

from tianshou.data import Batch
from tianshou.policy import PGPolicy


class A2CPolicy(PGPolicy):
    """docstring for A2CPolicy"""

    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
        super().__init__(model, optim, dist_fn, discount_factor)
        self._w_value = vf_coef
        self._w_entropy = entropy_coef

    def __call__(self, batch, state=None):
        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
        logits = F.softmax(logits, dim=1)
        dist = self.dist_fn(logits)
        act = dist.sample().detach().cpu().numpy()
        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)

    def learn(self, batch, batch_size=None):
        losses = []
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            result = self(b)
            dist = result.dist
            v = result.value
            a = torch.tensor(b.act, device=dist.logits.device)
            r = torch.tensor(b.returns, device=dist.logits.device)
            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
            critic_loss = (r - v).pow(2).mean()
            entropy_loss = dist.entropy().mean()
            loss = actor_loss \
                + self._w_value * critic_loss \
                - self._w_entropy * entropy_loss
            loss.backward()
            self.optim.step()
            losses.append(loss.detach().cpu().numpy())
        return losses
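For reference, the loss assembled in `learn()` above is the standard A2C objective. Spelled out in my own notation (not part of the commit), with c_v = vf_coef, c_e = entropy_coef, and R_t = b.returns as produced by PGPolicy.process_fn:

    L(\theta) = -\mathbb{E}_t\big[\log \pi_\theta(a_t \mid s_t)\,(R_t - V_\theta(s_t))\big]
              + c_v\,\mathbb{E}_t\big[(R_t - V_\theta(s_t))^2\big]
              - c_e\,\mathbb{E}_t\big[\mathcal{H}(\pi_\theta(\cdot \mid s_t))\big]

The advantage (R_t - V_\theta(s_t)) is detached in the first term, so the actor gradient does not propagate into the critic head.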
tianshou/policy/policy_gradient.py

@@ -11,7 +11,7 @@ class PGPolicy(BasePolicy, nn.Module):
     """docstring for PGPolicy"""
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, normalized_reward=True):
+                 discount_factor=0.99):
         super().__init__()
         self.model = model
         self.optim = optim
@@ -19,15 +19,10 @@ class PGPolicy(BasePolicy, nn.Module):
         self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
-        self._rew_norm = normalized_reward
 
     def process_fn(self, batch, buffer, indice):
-        batch_size = len(batch.rew)
-        returns = self._vanilla_returns(batch, batch_size)
-        # returns = self._vectorized_returns(batch, batch_size)
-        returns = returns - returns.mean()
-        if self._rew_norm:
-            returns = returns / (returns.std() + self._eps)
+        returns = self._vanilla_returns(batch)
+        # returns = self._vectorized_returns(batch)
         batch.update(returns=returns)
         return batch
 
@@ -40,6 +35,8 @@ class PGPolicy(BasePolicy, nn.Module):
 
     def learn(self, batch, batch_size=None):
         losses = []
+        batch.returns = (batch.returns - batch.returns.mean()) \
+            / (batch.returns.std() + self._eps)
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
@@ -51,21 +48,22 @@ class PGPolicy(BasePolicy, nn.Module):
             losses.append(loss.detach().cpu().numpy())
         return losses
 
-    def _vanilla_returns(self, batch, batch_size):
+    def _vanilla_returns(self, batch):
         returns = batch.rew[:]
         last = 0
-        for i in range(batch_size - 1, -1, -1):
+        for i in range(len(returns) - 1, -1, -1):
             if not batch.done[i]:
                 returns[i] += self._gamma * last
             last = returns[i]
         return returns
 
-    def _vectorized_returns(self, batch, batch_size):
+    def _vectorized_returns(self, batch):
         # according to my tests, it is slower than vanilla
         # import scipy.signal
         convolve = np.convolve
         # convolve = scipy.signal.convolve
         rew = batch.rew[::-1]
+        batch_size = len(rew)
         gammas = self._gamma ** np.arange(batch_size)
         c = convolve(rew, gammas)[:batch_size]
         T = np.where(batch.done[::-1])[0]