maybe finished rnn?
parent d9d2763dad · commit 86572c66d4
@@ -256,7 +256,7 @@ Tianshou is still under development. More algorithms and features are going to b
 - [ ] More examples on [mujoco, atari] benchmark
 - [ ] More algorithms
 - [ ] Prioritized replay buffer
-- [ ] RNN support
+- [x] RNN support
 - [ ] Imitation Learning
 - [ ] Multi-agent
 - [ ] Distributed training
@@ -53,33 +53,39 @@ class Critic(nn.Module):
         return logits
 
 
-class DQN(nn.Module):
-
-    def __init__(self, h, w, action_shape, device='cpu'):
-        super(DQN, self).__init__()
-        self.device = device
-
-        self.conv1 = nn.Conv2d(4, 16, kernel_size=5, stride=2)
-        self.bn1 = nn.BatchNorm2d(16)
-        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
-        self.bn2 = nn.BatchNorm2d(32)
-        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
-        self.bn3 = nn.BatchNorm2d(32)
-
-        def conv2d_size_out(size, kernel_size=5, stride=2):
-            return (size - (kernel_size - 1) - 1) // stride + 1
-
-        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
-        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
-        linear_input_size = convw * convh * 32
-        self.fc = nn.Linear(linear_input_size, 512)
-        self.head = nn.Linear(512, action_shape)
-
-    def forward(self, x, state=None, info={}):
-        if not isinstance(x, torch.Tensor):
-            x = torch.tensor(x, device=self.device, dtype=torch.float)
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = F.relu(self.bn2(self.conv2(x)))
-        x = F.relu(self.bn3(self.conv3(x)))
-        x = self.fc(x.reshape(x.size(0), -1))
-        return self.head(x), state
+class Recurrent(nn.Module):
+
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+        super().__init__()
+        self.state_shape = state_shape
+        self.action_shape = action_shape
+        self.device = device
+        self.fc1 = nn.Linear(np.prod(state_shape), 128)
+        self.nn = nn.LSTM(input_size=128, hidden_size=128,
+                          num_layers=layer_num, batch_first=True)
+        self.fc2 = nn.Linear(128, np.prod(action_shape))
+
+    def forward(self, s, state=None, info={}):
+        if not isinstance(s, torch.Tensor):
+            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        # s has shape [bsz, len, dim] (training) or [bsz, dim] (evaluation),
+        # i.e. the tensor has one extra (sequence) dimension during training.
+        if len(s.shape) == 2:
+            bsz, dim = s.shape
+            length = 1
+        else:
+            bsz, length, dim = s.shape
+        s = self.fc1(s.view([bsz * length, dim]))
+        s = s.view(bsz, length, -1)
+        self.nn.flatten_parameters()
+        if state is None:
+            s, (h, c) = self.nn(s)
+        else:
+            # we store the hidden state batch-first as [bsz, ...],
+            # but pytorch's LSTM expects (h, c) as [num_layers, bsz, hidden]
+            s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(),
+                                    state['c'].transpose(0, 1).contiguous()))
+        s = self.fc2(s)[:, -1]
+        # please ensure the first dim is batch size: [bsz, len, ...]
+        return s, {'h': h.transpose(0, 1).detach(),
+                   'c': c.transpose(0, 1).detach()}
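To make the shape handling above concrete, here is a minimal usage sketch of the new Recurrent net. It is not part of the commit; the observation dimension of 4, the number of envs, and the batch/stack sizes are illustrative assumptions.

import torch
from net import Recurrent  # assuming this runs from inside test/discrete

net = Recurrent(layer_num=3, state_shape=(4,), action_shape=2)

# evaluation: one step per env, obs is [bsz, dim]; the hidden state is threaded
obs = torch.zeros(8, 4)              # 8 parallel envs
q, state = net(obs)                  # state == {'h': [8, 3, 128], 'c': [8, 3, 128]}
q, state = net(obs, state=state)     # the next step reuses the returned state

# training: frame-stacked samples, obs is [bsz, len, dim]; state starts fresh
batch_obs = torch.zeros(64, 4, 4)    # batch of 64, stack_num = 4
q, _ = net(batch_obs)                # only the last step's Q-values are kept
print(q.shape)                       # torch.Size([64, 2])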
test/discrete/test_drqn.py (new file, 113 lines)

import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import VectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

if __name__ == '__main__':
    from net import Recurrent
else:  # pytest
    from test.discrete.net import Recurrent


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--stack-num', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=3)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_drqn(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Recurrent(args.layer_num, args.state_shape,
                    args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(
        net, optim, args.gamma, args.n_step,
        use_target_network=args.target_update_freq > 0,
        target_update_freq=args.target_update_freq)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(
            args.buffer_size, stack_num=args.stack_num))
    # the stack_num is for RNN training: sample frame-stacked obs
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size)
    # log
    writer = SummaryWriter(args.logdir + '/' + 'dqn')

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    def train_fn(x):
        policy.set_eps(args.eps_train)

    def test_fn(x):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, writer=writer)

    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()


if __name__ == '__main__':
    test_drqn(get_args())
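A note on the import switch at the top of the file: when the script is executed directly it takes the `from net import Recurrent` branch, which likely requires launching it from inside test/discrete; under pytest the module is imported as part of the test package and the `from test.discrete.net import Recurrent` branch is used instead.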
@@ -41,9 +41,10 @@ class ReplayBuffer(object):
        array([ True, True, True, True])
    """

-    def __init__(self, size):
+    def __init__(self, size, stack_num=0):
         super().__init__()
         self._maxsize = size
+        self._stack = stack_num
         self.reset()
 
     def __len__(self):
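With the default stack_num=0 the buffer keeps its previous behaviour and returns un-stacked observations; only a positive stack_num activates the frame-stacking path shown in the next hunk.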
@@ -113,14 +114,28 @@ class ReplayBuffer(object):
         ])
         return self[indice], indice
 
+    def _get_stack(self, indice, key):
+        if self.__dict__.get(key, None) is None:
+            return None
+        if self._stack == 0:
+            return self.__dict__[key][indice]
+        stack = []
+        for i in range(self._stack):
+            stack = [self.__dict__[key][indice]] + stack
+            indice = indice - 1 + self.done[indice - 1].astype(np.int)
+            indice[indice == -1] = self._size - 1
+        return np.stack(stack, axis=1)
+
     def __getitem__(self, index):
-        """Return a data batch: self[index]."""
+        """Return a data batch: self[index]. If stack_num is larger than 0,
+        return the stacked obs and obs_next with shape [batch, len, ...].
+        """
         return Batch(
-            obs=self.obs[index],
+            obs=self._get_stack(index, 'obs'),
             act=self.act[index],
             rew=self.rew[index],
             done=self.done[index],
-            obs_next=self.obs_next[index],
+            obs_next=self._get_stack(index, 'obs_next'),
             info=self.info[index]
         )
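The interesting part of _get_stack is how it walks backwards through the ring buffer without crossing episode boundaries. Here is a standalone numpy sketch of that index trick; the data is made up for illustration and is not from the commit.

import numpy as np

obs = np.array([10, 11, 12, 20, 21, 22, 23])        # two episodes of observations
done = np.array([0, 0, 1, 0, 0, 0, 1], dtype=bool)  # episodes end at index 2 and 6
size, stack_num = len(obs), 3

indice = np.array([0, 2, 4, 6])                     # indices being sampled
stack = []
for _ in range(stack_num):
    stack = [obs[indice]] + stack
    # step one index back, unless the previous index finished an episode,
    # in which case stay put and repeat the episode's first frame
    indice = indice - 1 + done[indice - 1].astype(int)
    indice[indice == -1] = size - 1                 # ring-buffer wrap-around
print(np.stack(stack, axis=1))
# [[10 10 10]   <- index 0 starts an episode, so its first frame is repeated
#  [10 11 12]
#  [20 20 21]   <- index 4 can only look back to index 3, the start of episode 2
#  [21 22 23]]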
@@ -150,12 +150,30 @@ class Collector(object):
         self.env.close()
 
     def _make_batch(self, data):
+        """Return [data]."""
         if isinstance(data, np.ndarray):
             return data[None]
         else:
             return np.array([data])
 
-    def collect(self, n_step=0, n_episode=0, render=0):
+    def _reset_state(self, id):
+        """Reset self.state[id]."""
+        if self.state is None:
+            return
+        if isinstance(self.state, list):
+            self.state[id] = None
+        elif isinstance(self.state, dict):
+            for k in self.state:
+                if isinstance(self.state[k], list):
+                    self.state[k][id] = None
+                elif isinstance(self.state[k], torch.Tensor) or \
+                        isinstance(self.state[k], np.ndarray):
+                    self.state[k][id] = 0
+        elif isinstance(self.state, torch.Tensor) or \
+                isinstance(self.state, np.ndarray):
+            self.state[id] = 0
+
+    def collect(self, n_step=0, n_episode=0, render=None):
         """Collect a specified number of step or episode.
 
         :param int n_step: how many steps you want to collect.
@@ -163,7 +181,7 @@ class Collector(object):
             environment).
         :type n_episode: int or list
         :param float render: the sleep time between rendering consecutive
-            frames. No rendering if it is ``0`` (default option).
+            frames, defaults to ``None`` (no rendering).
 
         .. note::
 
@@ -218,9 +236,10 @@ class Collector(object):
                 self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
-            if render > 0:
+            if render is not None:
                 self.env.render()
-                time.sleep(render)
+                if render > 0:
+                    time.sleep(render)
             self.length += 1
             self.reward += self._rew
             if self._multi_env:
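The render semantics change slightly here: render=None (the new default) disables rendering entirely, render=0 renders every frame without sleeping, and any positive value renders and then sleeps that many seconds between consecutive frames.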
@@ -253,16 +272,7 @@ class Collector(object):
                     self.reward[i], self.length[i] = 0, 0
                     if self._cached_buf:
                         self._cached_buf[i].reset()
-                    if isinstance(self.state, list):
-                        self.state[i] = None
-                    elif self.state is not None:
-                        if isinstance(self.state[i], dict):
-                            self.state[i] = {}
-                        else:
-                            self.state[i] = self.state[i] * 0
-                        if isinstance(self.state, torch.Tensor):
-                            # remove ref count in pytorch (?)
-                            self.state = self.state.detach()
+                    self._reset_state(i)
             if sum(self._done):
                 obs_next = self.env.reset(np.where(self._done)[0])
                 if n_episode != 0:
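For the recurrent policy, the collector's per-env state is the {'h': ..., 'c': ...} dict returned by the Recurrent net above, so _reset_state(i) boils down to zeroing env i's slice of the hidden state. A hedged illustration, with shapes assumed rather than taken from the commit:

import torch

# hidden state for 8 envs, 3 LSTM layers, hidden size 128
state = {'h': torch.randn(8, 3, 128), 'c': torch.randn(8, 3, 128)}

i = 5                   # env 5 just finished an episode
for k in state:
    state[k][i] = 0     # what Collector._reset_state(5) does for the dict case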
@@ -27,7 +27,7 @@ class A2CPolicy(PGPolicy):
                  dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, ent_coef=.01,
                  max_grad_norm=None, **kwargs):
-        super().__init__(None, optim, dist_fn, discount_factor)
+        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
         self.actor = actor
         self.critic = critic
         self._w_vf = vf_coef
@@ -34,7 +34,7 @@ class DDPGPolicy(BasePolicy):
                  tau=0.005, gamma=0.99, exploration_noise=0.1,
                  action_range=None, reward_normalization=False,
                  ignore_done=False, **kwargs):
-        super().__init__()
+        super().__init__(**kwargs)
         if actor is not None:
             self.actor, self.actor_old = actor, deepcopy(actor)
             self.actor_old.eval()
@@ -22,7 +22,7 @@ class DQNPolicy(BasePolicy):
 
     def __init__(self, model, optim, discount_factor=0.99,
                  estimation_step=1, target_update_freq=0, **kwargs):
-        super().__init__()
+        super().__init__(**kwargs)
         self.model = model
         self.optim = optim
         self.eps = 0
@@ -17,7 +17,7 @@ class PGPolicy(BasePolicy):
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, **kwargs):
-        super().__init__()
+        super().__init__(**kwargs)
         self.model = model
         self.optim = optim
         self.dist_fn = dist_fn
@@ -36,7 +36,7 @@ class PPOPolicy(PGPolicy):
                  ent_coef=.0,
                  action_range=None,
                  **kwargs):
-        super().__init__(None, None, dist_fn, discount_factor)
+        super().__init__(None, None, dist_fn, discount_factor, **kwargs)
         self._max_grad_norm = max_grad_norm
         self._eps_clip = eps_clip
         self._w_vf = vf_coef
@@ -40,7 +40,8 @@ class SACPolicy(DDPGPolicy):
                  alpha=0.2, action_range=None, reward_normalization=False,
                  ignore_done=False, **kwargs):
         super().__init__(None, None, None, None, tau, gamma, 0,
-                         action_range, reward_normalization, ignore_done)
+                         action_range, reward_normalization, ignore_done,
+                         **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
@@ -46,7 +46,7 @@ class TD3Policy(DDPGPolicy):
                  reward_normalization=False, ignore_done=False, **kwargs):
         super().__init__(actor, actor_optim, None, None, tau, gamma,
                          exploration_noise, action_range, reward_normalization,
-                         ignore_done)
+                         ignore_done, **kwargs)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
         self.critic1_optim = critic1_optim
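The seven policy hunks above all make the same change: each constructor now forwards **kwargs to its parent instead of silently dropping them, so keyword arguments aimed at a base class actually reach it. A generic (non-Tianshou) sketch of why that matters:

class Base:
    def __init__(self, **kwargs):
        self.extra = kwargs          # options meant for the base class

class Child(Base):
    def __init__(self, lr=1e-3, **kwargs):
        super().__init__(**kwargs)   # without this call, base_opt would be lost
        self.lr = lr

c = Child(lr=3e-4, base_opt=True)
print(c.extra)                       # {'base_opt': True}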
@@ -53,9 +53,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         policy.train()
         if train_fn:
             train_fn(epoch)
-        with tqdm.tqdm(
-                total=step_per_epoch, desc=f'Epoch #{epoch}',
-                **tqdm_config) as t:
+        with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
+                       **tqdm_config) as t:
             while t.n < t.total:
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
@@ -58,9 +58,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         policy.train()
         if train_fn:
             train_fn(epoch)
-        with tqdm.tqdm(
-                total=step_per_epoch, desc=f'Epoch #{epoch}',
-                **tqdm_config) as t:
+        with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
+                       **tqdm_config) as t:
             while t.n < t.total:
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}