finish pg

Trinkle23897 2020-03-17 11:37:31 +08:00
parent 8b0b970c9b
commit 39de63592f
12 changed files with 362 additions and 57 deletions

View File

@@ -41,8 +41,8 @@ setup(
         'gym',
         'tqdm',
         'numpy',
-        'torch>=1.2.0',  # for supporting tensorboard
         'cloudpickle',
         'tensorboard',
+        'torch>=1.2.0',  # for supporting tensorboard
     ],
 )

View File

@@ -14,6 +14,9 @@ def test_batch():
     assert batch[0].obs == batch[1].obs
     with pytest.raises(IndexError):
         batch[2]
+    batch.obs = np.arange(5)
+    for i, b in enumerate(batch.split(1)):
+        assert b.obs == batch[i].obs


 if __name__ == '__main__':

View File

@@ -26,9 +26,7 @@ class Net(nn.Module):
         self.model = nn.Sequential(*self.model)

     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.Tensor(s)
-            s = s.to(self.device)
+        s = torch.tensor(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         q = self.model(s.view(batch, -1))
         return q, None

@@ -63,6 +61,7 @@ def test_dqn(args=get_args()):
     env = gym.make(args.task)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.training_num)],
         reset_after_done=True)

@@ -118,7 +117,7 @@ def test_dqn(args=get_args()):
                     'speed', result['speed'], global_step=global_step)
                 t.set_postfix(loss=f'{stat_loss.get():.6f}',
                               reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.6f}',
+                              length=f'{result["length"]:.2f}',
                               speed=f'{result["speed"]:.2f}')
         # eval
         test_collector.reset_env()

@@ -131,9 +130,11 @@ def test_dqn(args=get_args()):
             best_epoch = epoch
         print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
               f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if args.task == 'CartPole-v0' and best_reward >= 200:
+        if best_reward >= env.spec.reward_threshold:
             break
-    assert best_reward >= 200
+    assert best_reward >= env.spec.reward_threshold
+    training_collector.close()
+    test_collector.close()
     if __name__ == '__main__':
         train_cnt = training_collector.collect_step
         test_cnt = test_collector.collect_step

@@ -143,18 +144,10 @@ def test_dqn(args=get_args()):
               f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
         # Let's watch its performance!
         env = gym.make(args.task)
-        obs = env.reset()
-        done = False
-        total = 0
-        while not done:
-            q, _ = net([obs])
-            action = q.max(dim=1)[1]
-            obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
-            total += rew
-            env.render()
-            time.sleep(1 / 35)
-        env.close()
-        print(f'Final test: {total}')
+        test_collector = Collector(policy, env, ReplayBuffer(1))
+        result = test_collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
+        test_collector.close()


 if __name__ == '__main__':
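
The hard-coded 200 threshold is replaced by the task's own solve criterion. As a quick illustration (assuming a standard gym installation, where classic-control specs expose a reward_threshold), CartPole-v0 registers 195.0:

import gym

env = gym.make('CartPole-v0')
print(env.spec.reward_threshold)  # 195.0 for CartPole-v0
env.close()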

test/test_pg.py (new file, 207 lines)
View File

@@ -0,0 +1,207 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PGPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Batch, Collector, ReplayBuffer


def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
    returns = np.zeros_like(batch.rew)
    last = 0
    for i in reversed(range(len(batch.rew))):
        returns[i] = batch.rew[i]
        if not batch.done[i]:
            returns[i] += last * gamma
        last = returns[i]
    batch.update(returns=returns)
    return batch


def test_fn(size=2560):
    policy = PGPolicy(
        None, None, None, discount_factor=0.1, normalized_reward=False)
    fn = policy.process_fn
    # fn = compute_return_base
    batch = Batch(
        done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
        rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
    )
    batch = fn(batch, None, None)
    ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
    ans -= ans.mean()
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 1, 0.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, None, None)
    ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
    ans -= ans.mean()
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 0, 1.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, None, None)
    ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
    ans -= ans.mean()
    assert abs(batch.returns - ans).sum() <= 1e-5
    if __name__ == '__main__':
        batch = Batch(
            done=np.random.randint(100, size=size) == 0,
            rew=np.random.random(size),
        )
        cnt = 3000
        t = time.time()
        for _ in range(cnt):
            compute_return_base(batch)
        print(f'vanilla: {(time.time() - t) / cnt}')
        t = time.time()
        for _ in range(cnt):
            policy.process_fn(batch, None, None)
        print(f'policy: {(time.time() - t) / cnt}')


class Net(nn.Module):
    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
        super().__init__()
        self.device = device
        self.model = [
            nn.Linear(np.prod(state_shape), 128),
            nn.ReLU(inplace=True)]
        for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        self.model += [nn.Linear(128, np.prod(action_shape))]
        self.model = nn.Sequential(*self.model)

    def forward(self, s, **kwargs):
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        logits = self.model(s.view(batch, -1))
        return logits, None


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=320)
    parser.add_argument('--collect-per-step', type=int, default=5)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=3)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=20)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_pg(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PGPolicy(net, optim, dist, args.gamma)
    # collector
    training_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(
        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    for epoch in range(1, 1 + args.epoch):
        desc = f"Epoch #{epoch}"
        # train
        policy.train()
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_episode=args.collect_per_step)
                losses = policy.learn(
                    training_collector.sample(0), args.batch_size)
                global_step += len(losses)
                t.update(len(losses))
                stat_loss.add(losses)
                writer.add_scalar(
                    'reward', result['reward'], global_step=global_step)
                writer.add_scalar(
                    'length', result['length'], global_step=global_step)
                writer.add_scalar(
                    'loss', stat_loss.get(), global_step=global_step)
                writer.add_scalar(
                    'speed', result['speed'], global_step=global_step)
                t.set_postfix(loss=f'{stat_loss.get():.6f}',
                              reward=f'{result["reward"]:.6f}',
                              length=f'{result["length"]:.2f}',
                              speed=f'{result["speed"]:.2f}')
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if best_reward >= env.spec.reward_threshold:
            break
    assert best_reward >= env.spec.reward_threshold
    training_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env, ReplayBuffer(1))
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()


if __name__ == '__main__':
    # test_fn()
    test_pg()
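
The expected arrays in test_fn can be reproduced by hand with the same backward recursion used by compute_return_base; a minimal standalone check for the first case (gamma = 0.1), before the mean subtraction that PGPolicy.process_fn applies:

import numpy as np

done = np.array([1, 0, 0, 1, 0, 1, 0, 1.])
rew = np.array([0, 1, 2, 3, 4, 5, 6, 7.])
gamma, last = 0.1, 0.0
returns = np.zeros_like(rew)
for i in reversed(range(len(rew))):
    # reset the running return at episode boundaries (done[i] == 1)
    returns[i] = rew[i] + (0.0 if done[i] else gamma * last)
    last = returns[i]
print(returns)  # [0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7], i.e. `ans` before centering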

View File

@@ -38,3 +38,14 @@ class Batch(object):
             raise TypeError(
                 'No support for append with type {} in class Batch.'
                 .format(type(batch.__dict__[k])))
+
+    def split(self, size=None):
+        length = min([
+            len(self.__dict__[k]) for k in self.__dict__.keys()
+            if self.__dict__[k] is not None])
+        if size is None:
+            size = length
+        temp = 0
+        while temp < length:
+            yield self[temp:temp + size]
+            temp += size
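
A minimal sketch of what the new split generator does, using a plain index helper rather than the Batch class itself: it yields successive chunks of at most size items, and the last chunk may be shorter.

import numpy as np

def split_indices(length, size=None):
    # mirrors Batch.split: size=None means a single chunk covering everything
    if size is None:
        size = length
    temp = 0
    while temp < length:
        yield temp, min(temp + size, length)
        temp += size

obs = np.arange(5)
print([obs[a:b] for a, b in split_indices(len(obs), size=2)])
# [array([0, 1]), array([2, 3]), array([4])]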

View File

@@ -59,7 +59,10 @@ class ReplayBuffer(object):
         if batch_size > 0:
             indice = np.random.choice(self._size, batch_size)
         else:
-            indice = np.arange(self._size)
+            indice = np.concatenate([
+                np.arange(self._index, self._size),
+                np.arange(0, self._index),
+            ])
         return Batch(
             obs=self.obs[indice],
             act=self.act[indice],
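
The point of the new index order is to return the whole ring buffer in chronological order (which the policy-gradient return computation needs), under the assumption that self._index is the next write position, so the oldest data sit in the slots from _index to the end:

import numpy as np

size, index = 6, 2  # 6 filled slots, write pointer currently at slot 2
indice = np.concatenate([
    np.arange(index, size),  # oldest entries: slots 2..5
    np.arange(0, index),     # newest entries that wrapped around: slots 0..1
])
print(indice)  # [2 3 4 5 0 1] -- oldest to newest, instead of plain arange(6)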

View File

@@ -121,8 +121,10 @@ class VectorEnv(BaseVectorEnv):
             np.stack(self._done), np.stack(self._info)

     def seed(self, seed=None):
-        if np.isscalar(seed) or seed is None:
-            seed = [seed for _ in range(self.env_num)]
+        if np.isscalar(seed):
+            seed = [seed + _ for _ in range(self.env_num)]
+        elif seed is None:
+            seed = [seed] * self.env_num
         for e, s in zip(self.envs, seed):
             if hasattr(e, 'seed'):
                 e.seed(s)

@@ -198,8 +200,10 @@ class SubprocVectorEnv(BaseVectorEnv):
         return np.stack([p.recv() for p in self.parent_remote])

     def seed(self, seed=None):
-        if np.isscalar(seed) or seed is None:
-            seed = [seed for _ in range(self.env_num)]
+        if np.isscalar(seed):
+            seed = [seed + _ for _ in range(self.env_num)]
+        elif seed is None:
+            seed = [seed] * self.env_num
         for p, s in zip(self.parent_remote, seed):
             p.send(['seed', s])
         for p in self.parent_remote:

@@ -272,8 +276,10 @@ class RayVectorEnv(BaseVectorEnv):
     def seed(self, seed=None):
         if not hasattr(self.envs[0], 'seed'):
             return
-        if np.isscalar(seed) or seed is None:
-            seed = [seed for _ in range(self.env_num)]
+        if np.isscalar(seed):
+            seed = [seed + _ for _ in range(self.env_num)]
+        elif seed is None:
+            seed = [seed] * self.env_num
         result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
         for r in result_obj:
             ray.get(r)
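
All three vectorized-env classes now share the same seed-broadcast rule; a standalone sketch of that rule (not the classes themselves):

import numpy as np

def expand_seed(seed, env_num):
    if np.isscalar(seed):
        # a single int fans out to distinct per-env seeds
        return [seed + i for i in range(env_num)]
    elif seed is None:
        # None stays None for every env (each env seeds itself)
        return [None] * env_num
    return seed  # an explicit list is passed through unchanged

print(expand_seed(1626, 4))  # [1626, 1627, 1628, 1629]
print(expand_seed(None, 3))  # [None, None, None]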

View File

@@ -1,7 +1,9 @@
 from tianshou.policy.base import BasePolicy
 from tianshou.policy.dqn import DQNPolicy
+from tianshou.policy.policy_gradient import PGPolicy

 __all__ = [
     'BasePolicy',
     'DQNPolicy',
+    'PGPolicy',
 ]

View File

@@ -8,17 +8,17 @@ class BasePolicy(ABC):
         super().__init__()
         self.model = None

-    @abstractmethod
-    def __call__(self, batch, hidden_state=None):
-        # return Batch(act=np.array(), state=None, ...)
-        pass
-
-    @abstractmethod
-    def learn(self, batch):
-        pass
-
     def process_fn(self, batch, buffer, indice):
         return batch

+    @abstractmethod
+    def __call__(self, batch, state=None):
+        # return Batch(logits=..., act=np.array(), state=None, ...)
+        pass
+
+    @abstractmethod
+    def learn(self, batch, batch_size=None):
+        pass
+
     def sync_weight(self):
         pass

View File

@@ -10,14 +10,14 @@ from tianshou.policy import BasePolicy

 class DQNPolicy(BasePolicy, nn.Module):
     """docstring for DQNPolicy"""

-    def __init__(self, model, optim, loss,
+    def __init__(self, model, optim, loss_fn,
                  discount_factor=0.99,
                  estimation_step=1,
                  use_target_network=True):
         super().__init__()
         self.model = model
         self.optim = optim
-        self.loss = loss
+        self.loss_fn = loss_fn
         self.eps = 0
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor

@@ -28,20 +28,6 @@ class DQNPolicy(BasePolicy, nn.Module):
             self.model_old = deepcopy(self.model)
             self.model_old.eval()

-    def __call__(self, batch, hidden_state=None,
-                 model='model', input='obs', eps=None):
-        model = getattr(self, model)
-        obs = getattr(batch, input)
-        q, h = model(obs, hidden_state=hidden_state, info=batch.info)
-        act = q.max(dim=1)[1].detach().cpu().numpy()
-        # add eps to act
-        if eps is None:
-            eps = self.eps
-        for i in range(len(q)):
-            if np.random.rand() < eps:
-                act[i] = np.random.randint(q.shape[1])
-        return Batch(Q=q, act=act, state=h)
-
     def set_eps(self, eps):
         self.eps = eps

@@ -70,12 +56,12 @@ class DQNPolicy(BasePolicy, nn.Module):
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
             a = self(buffer[terminal], input='obs_next', eps=0).act
             target_q = self(
-                buffer[terminal], model='model_old', input='obs_next').Q
+                buffer[terminal], model='model_old', input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
                 target_q = target_q.detach().cpu().numpy()
             target_q = target_q[np.arange(len(a)), a]
         else:
-            target_q = self(buffer[terminal], input='obs_next').Q
+            target_q = self(buffer[terminal], input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
                 target_q = target_q.detach().cpu().numpy()
             target_q = target_q.max(axis=1)

@@ -84,14 +70,28 @@ class DQNPolicy(BasePolicy, nn.Module):
         batch.update(returns=returns)
         return batch

-    def learn(self, batch):
+    def __call__(self, batch, state=None,
+                 model='model', input='obs', eps=None):
+        model = getattr(self, model)
+        obs = getattr(batch, input)
+        q, h = model(obs, state=state, info=batch.info)
+        act = q.max(dim=1)[1].detach().cpu().numpy()
+        # add eps to act
+        if eps is None:
+            eps = self.eps
+        for i in range(len(q)):
+            if np.random.rand() < eps:
+                act[i] = np.random.randint(q.shape[1])
+        return Batch(logits=q, act=act, state=h)
+
+    def learn(self, batch, batch_size=None):
         self.optim.zero_grad()
-        q = self(batch).Q
+        q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
         r = batch.returns
         if isinstance(r, np.ndarray):
             r = torch.tensor(r, device=q.device, dtype=q.dtype)
-        loss = self.loss(q, r)
+        loss = self.loss_fn(q, r)
         loss.backward()
         self.optim.step()
         return loss.detach().cpu().numpy()
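
The relocated __call__ keeps the same epsilon-greedy rule; a standalone sketch with made-up Q-values and eps (for illustration only, not the DQNPolicy class itself):

import numpy as np
import torch

q = torch.tensor([[1.0, 3.0], [2.0, 0.5], [0.1, 0.2]])  # fake Q-values
eps = 0.1
act = q.max(dim=1)[1].detach().cpu().numpy()  # greedy: argmax over actions
for i in range(len(q)):
    if np.random.rand() < eps:  # with probability eps, explore uniformly
        act[i] = np.random.randint(q.shape[1])
print(act)  # usually [1, 0, 1]; occasionally one entry is randomized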

View File

@@ -0,0 +1,76 @@
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class PGPolicy(BasePolicy, nn.Module):
    """docstring for PGPolicy"""

    def __init__(self, model, optim, dist=Categorical,
                 discount_factor=0.99, normalized_reward=True):
        super().__init__()
        self.model = model
        self.optim = optim
        self.dist = dist
        self._eps = np.finfo(np.float32).eps.item()
        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
        self._gamma = discount_factor
        self._rew_norm = normalized_reward

    def process_fn(self, batch, buffer, indice):
        batch_size = len(batch.rew)
        returns = self._vanilla_returns(batch, batch_size)
        # returns = self._vectorized_returns(batch, batch_size)
        returns = returns - returns.mean()
        if self._rew_norm:
            returns = returns / (returns.std() + self._eps)
        batch.update(returns=returns)
        return batch

    def __call__(self, batch, state=None):
        logits, h = self.model(batch.obs, state=state, info=batch.info)
        logits = F.softmax(logits, dim=1)
        dist = self.dist(logits)
        act = dist.sample().detach().cpu().numpy()
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(self, batch, batch_size=None):
        losses = []
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            a = torch.tensor(b.act, device=dist.logits.device)
            r = torch.tensor(b.returns, device=dist.logits.device)
            loss = -(dist.log_prob(a) * r).sum()
            loss.backward()
            self.optim.step()
            losses.append(loss.detach().cpu().numpy())
        return losses

    def _vanilla_returns(self, batch, batch_size):
        returns = batch.rew[:]
        last = 0
        for i in range(batch_size - 1, -1, -1):
            if not batch.done[i]:
                returns[i] += self._gamma * last
            last = returns[i]
        return returns

    def _vectorized_returns(self, batch, batch_size):
        # according to my tests, it is slower than vanilla
        # import scipy.signal
        convolve = np.convolve
        # convolve = scipy.signal.convolve
        rew = batch.rew[::-1]
        gammas = self._gamma ** np.arange(batch_size)
        c = convolve(rew, gammas)[:batch_size]
        T = np.where(batch.done[::-1])[0]
        d = np.zeros_like(rew)
        d[T] += c[T] - rew[T]
        d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
        return (c - convolve(d, gammas)[:batch_size])[::-1]
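
A minimal standalone sketch of the update performed inside PGPolicy.learn: sample actions from a Categorical over softmaxed logits and descend on -(log_prob * return).sum(). The tensors below are made up for illustration.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 2, requires_grad=True)      # fake network output
dist = torch.distributions.Categorical(F.softmax(logits, dim=1))
act = dist.sample()                                  # sampled actions, shape (4,)
returns = torch.tensor([1.0, -0.5, 0.3, 0.2])        # centered discounted returns
loss = -(dist.log_prob(act) * returns).sum()         # REINFORCE objective
loss.backward()                                      # gradients flow to the logits
print(loss.item(), logits.grad.shape)                # scalar loss, grad of shape (4, 2)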

View File

@@ -11,7 +11,11 @@ class MovAvg(object):
     def add(self, x):
         if isinstance(x, torch.Tensor):
             x = x.detach().cpu().numpy()
-        if x != np.inf:
+        if isinstance(x, list):
+            for _ in x:
+                if _ != np.inf:
+                    self.cache.append(_)
+        elif x != np.inf:
             self.cache.append(x)
         if self.size > 0 and len(self.cache) > self.size:
             self.cache = self.cache[-self.size:]
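
PGPolicy.learn returns one loss per mini-batch, so the moving average now accepts a list and pushes every element into its window; a tiny standalone sketch of that behaviour (a simplified stand-in, not tianshou's MovAvg class):

import numpy as np

class TinyMovAvg:
    def __init__(self, size=100):
        self.size, self.cache = size, []

    def add(self, x):
        items = x if isinstance(x, list) else [x]
        self.cache += [v for v in items if v != np.inf]  # skip inf placeholders
        self.cache = self.cache[-self.size:]

    def get(self):
        return np.mean(self.cache) if self.cache else 0

avg = TinyMovAvg()
avg.add([0.5, 0.7, np.inf])  # a whole list of per-mini-batch losses
avg.add(0.6)
print(avg.get())  # 0.6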