maybe finished rnn?

Trinkle23897 2020-04-08 21:13:15 +08:00
parent d9d2763dad
commit 86572c66d4
14 changed files with 203 additions and 60 deletions


@@ -256,7 +256,7 @@ Tianshou is still under development. More algorithms and features are going to b
- [ ] More examples on [mujoco, atari] benchmark
- [ ] More algorithms
- [ ] Prioritized replay buffer
- [ ] RNN support
- [x] RNN support
- [ ] Imitation Learning
- [ ] Multi-agent
- [ ] Distributed training


@@ -53,33 +53,39 @@ class Critic(nn.Module):
return logits
class DQN(nn.Module):
def __init__(self, h, w, action_shape, device='cpu'):
super(DQN, self).__init__()
class Recurrent(nn.Module):
def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
super().__init__()
self.state_shape = state_shape
self.action_shape = action_shape
self.device = device
self.fc1 = nn.Linear(np.prod(state_shape), 128)
self.nn = nn.LSTM(input_size=128, hidden_size=128,
num_layers=layer_num, batch_first=True)
self.fc2 = nn.Linear(128, np.prod(action_shape))
self.conv1 = nn.Conv2d(4, 16, kernel_size=5, stride=2)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
self.bn3 = nn.BatchNorm2d(32)
def conv2d_size_out(size, kernel_size=5, stride=2):
return (size - (kernel_size - 1) - 1) // stride + 1
convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
linear_input_size = convw * convh * 32
self.fc = nn.Linear(linear_input_size, 512)
self.head = nn.Linear(512, action_shape)
def forward(self, x, state=None, info={}):
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.fc(x.reshape(x.size(0), -1))
return self.head(x), state
def forward(self, s, state=None, info={}):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape during training has one extra dimension
# (the sequence length) compared with evaluation.
if len(s.shape) == 2:
bsz, dim = s.shape
length = 1
else:
bsz, length, dim = s.shape
s = self.fc1(s.view([bsz * length, dim]))
s = s.view(bsz, length, -1)
self.nn.flatten_parameters()
if state is None:
s, (h, c) = self.nn(s)
else:
# we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...]
s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(),
state['c'].transpose(0, 1).contiguous()))
s = self.fc2(s)[:, -1]
# please ensure the first dim is batch size: [bsz, len, ...]
return s, {'h': h.transpose(0, 1).detach(),
'c': c.transpose(0, 1).detach()}
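
For orientation, a minimal sketch of the shape contract the new Recurrent module implements: during collection the observation arrives as [bsz, dim] and the hidden state travels as a dict of [bsz, num_layers, hidden_size] tensors, while during training the frame-stacked observation is [bsz, len, dim] and no state is carried. The snippet is illustrative only; it assumes CartPole's 4-dimensional observation, a 2-action output, and that it is run from test/discrete/ so that net.py is importable.

    import numpy as np
    from net import Recurrent  # the class added above (test/discrete/net.py)

    net = Recurrent(layer_num=3, state_shape=(4,), action_shape=2)

    # collection / evaluation: one frame per env, hidden state carried over
    obs = np.random.randn(8, 4)            # [bsz, dim]
    logits, state = net(obs)               # state == {'h': ..., 'c': ...}
    logits, state = net(obs, state)        # feed the detached state back in
    print(logits.shape)                    # expected: torch.Size([8, 2])
    print(state['h'].shape)                # expected: torch.Size([8, 3, 128])

    # training: frame-stacked input, no carried state
    obs_stack = np.random.randn(8, 4, 4)   # [bsz, len, dim]
    logits, _ = net(obs_stack)
    print(logits.shape)                    # expected: torch.Size([8, 2])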

test/discrete/test_drqn.py (new file, 113 lines)

@@ -0,0 +1,113 @@
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.env import VectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
if __name__ == '__main__':
from net import Recurrent
else: # pytest
from test.discrete.net import Recurrent
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--stack-num', type=int, default=4)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
def test_drqn(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Recurrent(args.layer_num, args.state_shape,
args.action_shape, args.device)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,
use_target_network=args.target_update_freq > 0,
target_update_freq=args.target_update_freq)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(
args.buffer_size, stack_num=args.stack_num))
# the stack_num is for RNN training: sample framestack obs
test_collector = Collector(policy, test_envs)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
# log
writer = SummaryWriter(args.logdir + '/' + 'dqn')
def stop_fn(x):
return x >= env.spec.reward_threshold
def train_fn(x):
policy.set_eps(args.eps_train)
def test_fn(x):
policy.set_eps(args.eps_test)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
if __name__ == '__main__':
test_drqn(get_args())
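
To see what the pieces above exchange at run time, here is a self-contained sketch (not part of the commit) that wires a single CartPole env, the Recurrent net, DQNPolicy and a stack_num-ed ReplayBuffer together and inspects the sampled shapes. It assumes it is run from test/discrete/ and only uses argument values that appear in the diff.

    import gym
    import torch
    from net import Recurrent
    from tianshou.policy import DQNPolicy
    from tianshou.data import Collector, ReplayBuffer

    env = gym.make('CartPole-v0')
    net = Recurrent(3, env.observation_space.shape, env.action_space.n)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    policy = DQNPolicy(net, optim, discount_factor=0.9,
                       target_update_freq=320)
    buf = ReplayBuffer(20000, stack_num=4)
    collector = Collector(policy, env, buf)

    collector.collect(n_step=64)      # fills the buffer, carrying LSTM state
    batch, indice = buf.sample(16)
    print(batch.obs.shape)            # expected: (16, 4, 4) = [batch, stack, dim]
    q, _ = net(batch.obs)             # stacked obs go straight into the RNN
    print(q.shape)                    # expected: torch.Size([16, 2])
    collector.close()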


@@ -41,9 +41,10 @@ class ReplayBuffer(object):
array([ True, True, True, True])
"""
def __init__(self, size):
def __init__(self, size, stack_num=0):
super().__init__()
self._maxsize = size
self._stack = stack_num
self.reset()
def __len__(self):
@@ -113,14 +114,28 @@ class ReplayBuffer(object):
])
return self[indice], indice
def _get_stack(self, indice, key):
if self.__dict__.get(key, None) is None:
return None
if self._stack == 0:
return self.__dict__[key][indice]
stack = []
for i in range(self._stack):
stack = [self.__dict__[key][indice]] + stack
indice = indice - 1 + self.done[indice - 1].astype(np.int)
indice[indice == -1] = self._size - 1
return np.stack(stack, axis=1)
def __getitem__(self, index):
"""Return a data batch: self[index]."""
"""Return a data batch: self[index]. If stack_num is set to be > 0,
return the stacked obs and obs_next with shape [batch, len, ...].
"""
return Batch(
obs=self.obs[index],
obs=self._get_stack(index, 'obs'),
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self.obs_next[index],
obs_next=self._get_stack(index, 'obs_next'),
info=self.info[index]
)
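
The _get_stack loop above walks backwards stack_num frames from each sampled index and, by checking the done flag of the preceding frame, refuses to cross episode boundaries: the first observation of an episode is simply repeated. A tiny self-contained check of that behaviour (a sketch; it assumes the existing add(obs, act, rew, done, obs_next, ...) keyword interface, and the numbers are made up):

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=10, stack_num=3)
    for i in range(6):
        # two 3-step episodes: frames 0-2 and 3-5, done on frames 2 and 5
        buf.add(obs=i, act=0, rew=0, done=(i in (2, 5)), obs_next=i + 1)

    print(buf[np.array([2, 3, 4, 5])].obs)
    # expected, roughly:
    # [[0. 1. 2.]    index 2: frame 2 plus its two predecessors
    #  [3. 3. 3.]    index 3: first frame of episode two, padded by repetition
    #  [3. 3. 4.]
    #  [3. 4. 5.]]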


@@ -150,12 +150,30 @@ class Collector(object):
self.env.close()
def _make_batch(self, data):
"""Return [data]."""
if isinstance(data, np.ndarray):
return data[None]
else:
return np.array([data])
def collect(self, n_step=0, n_episode=0, render=0):
def _reset_state(self, id):
"""Reset self.state[id]."""
if self.state is None:
return
if isinstance(self.state, list):
self.state[id] = None
elif isinstance(self.state, dict):
for k in self.state:
if isinstance(self.state[k], list):
self.state[k][id] = None
elif isinstance(self.state[k], torch.Tensor) or \
isinstance(self.state[k], np.ndarray):
self.state[k][id] = 0
elif isinstance(self.state, torch.Tensor) or \
isinstance(self.state, np.ndarray):
self.state[id] = 0
def collect(self, n_step=0, n_episode=0, render=None):
"""Collect a specified number of step or episode.
:param int n_step: how many steps you want to collect.
@@ -163,7 +181,7 @@
environment).
:type n_episode: int or list
:param float render: the sleep time between rendering consecutive
frames. No rendering if it is ``0`` (default option).
frames, defaults to ``None`` (no rendering).
.. note::
@@ -218,9 +236,10 @@
self._act = result.act
obs_next, self._rew, self._done, self._info = self.env.step(
self._act if self._multi_env else self._act[0])
if render > 0:
if render is not None:
self.env.render()
time.sleep(render)
if render > 0:
time.sleep(render)
self.length += 1
self.reward += self._rew
if self._multi_env:
@@ -253,16 +272,7 @@
self.reward[i], self.length[i] = 0, 0
if self._cached_buf:
self._cached_buf[i].reset()
if isinstance(self.state, list):
self.state[i] = None
elif self.state is not None:
if isinstance(self.state[i], dict):
self.state[i] = {}
else:
self.state[i] = self.state[i] * 0
if isinstance(self.state, torch.Tensor):
# remove ref count in pytorch (?)
self.state = self.state.detach()
self._reset_state(i)
if sum(self._done):
obs_next = self.env.reset(np.where(self._done)[0])
if n_episode != 0:
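
The new _reset_state helper clears only the slot of the environment whose episode just ended, whatever container the policy stored its hidden state in. A rough illustration of the dict-of-tensors case (shapes follow the Recurrent net above; this is not code from the commit):

    import torch

    # hidden state for 8 parallel envs, as returned by the Recurrent net above
    state = {'h': torch.randn(8, 3, 128), 'c': torch.randn(8, 3, 128)}

    def reset_state(state, i):
        # zero env i's rows so its next episode starts from a fresh hidden
        # state, mirroring the dict branch of Collector._reset_state
        for k in state:
            state[k][i] = 0

    reset_state(state, 2)                  # env #2 just finished an episode
    print(state['h'][2].abs().sum())       # tensor(0.)
    print(state['h'][3].abs().sum() > 0)   # tensor(True): others untouched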


@@ -27,7 +27,7 @@ class A2CPolicy(PGPolicy):
dist_fn=torch.distributions.Categorical,
discount_factor=0.99, vf_coef=.5, ent_coef=.01,
max_grad_norm=None, **kwargs):
super().__init__(None, optim, dist_fn, discount_factor)
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor
self.critic = critic
self._w_vf = vf_coef


@@ -34,7 +34,7 @@ class DDPGPolicy(BasePolicy):
tau=0.005, gamma=0.99, exploration_noise=0.1,
action_range=None, reward_normalization=False,
ignore_done=False, **kwargs):
super().__init__()
super().__init__(**kwargs)
if actor is not None:
self.actor, self.actor_old = actor, deepcopy(actor)
self.actor_old.eval()


@@ -22,7 +22,7 @@ class DQNPolicy(BasePolicy):
def __init__(self, model, optim, discount_factor=0.99,
estimation_step=1, target_update_freq=0, **kwargs):
super().__init__()
super().__init__(**kwargs)
self.model = model
self.optim = optim
self.eps = 0


@@ -17,7 +17,7 @@ class PGPolicy(BasePolicy):
def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
discount_factor=0.99, **kwargs):
super().__init__()
super().__init__(**kwargs)
self.model = model
self.optim = optim
self.dist_fn = dist_fn


@@ -36,7 +36,7 @@ class PPOPolicy(PGPolicy):
ent_coef=.0,
action_range=None,
**kwargs):
super().__init__(None, None, dist_fn, discount_factor)
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
self._max_grad_norm = max_grad_norm
self._eps_clip = eps_clip
self._w_vf = vf_coef


@@ -40,7 +40,8 @@ class SACPolicy(DDPGPolicy):
alpha=0.2, action_range=None, reward_normalization=False,
ignore_done=False, **kwargs):
super().__init__(None, None, None, None, tau, gamma, 0,
action_range, reward_normalization, ignore_done)
action_range, reward_normalization, ignore_done,
**kwargs)
self.actor, self.actor_optim = actor, actor_optim
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()


@@ -46,7 +46,7 @@ class TD3Policy(DDPGPolicy):
reward_normalization=False, ignore_done=False, **kwargs):
super().__init__(actor, actor_optim, None, None, tau, gamma,
exploration_noise, action_range, reward_normalization,
ignore_done)
ignore_done, **kwargs)
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
self.critic1_optim = critic1_optim
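
The remaining policy changes all follow one pattern: each constructor now passes **kwargs on to its parent, so extra keyword options reach the base class instead of stopping at the first __init__ in the chain. A toy sketch of the forwarding pattern (these classes are illustrative, not the Tianshou ones):

    class Base:
        def __init__(self, **kwargs):
            # options intended for the base class arrive here intact
            self.extra = kwargs

    class Middle(Base):
        def __init__(self, gamma=0.99, **kwargs):
            super().__init__(**kwargs)
            self.gamma = gamma

    class Leaf(Middle):
        def __init__(self, tau=0.005, **kwargs):
            super().__init__(**kwargs)
            self.tau = tau

    leaf = Leaf(tau=0.01, gamma=0.9, base_option=True)
    print(leaf.gamma, leaf.tau, leaf.extra)   # 0.9 0.01 {'base_option': True}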


@@ -53,9 +53,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
policy.train()
if train_fn:
train_fn(epoch)
with tqdm.tqdm(
total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
while t.n < t.total:
result = train_collector.collect(n_step=collect_per_step)
data = {}


@@ -58,9 +58,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
policy.train()
if train_fn:
train_fn(epoch)
with tqdm.tqdm(
total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
while t.n < t.total:
result = train_collector.collect(n_episode=collect_per_step)
data = {}