diff --git a/README.md b/README.md
index bd57171..6267588 100644
--- a/README.md
+++ b/README.md
@@ -256,7 +256,7 @@ Tianshou is still under development. More algorithms and features are going to b
 - [ ] More examples on [mujoco, atari] benchmark
 - [ ] More algorithms
 - [ ] Prioritized replay buffer
-- [ ] RNN support
+- [x] RNN support
 - [ ] Imitation Learning
 - [ ] Multi-agent
 - [ ] Distributed training
diff --git a/test/discrete/net.py b/test/discrete/net.py
index d64c32f..ad1f4e7 100644
--- a/test/discrete/net.py
+++ b/test/discrete/net.py
@@ -53,33 +53,39 @@ class Critic(nn.Module):
         return logits
 
 
-class DQN(nn.Module):
-
-    def __init__(self, h, w, action_shape, device='cpu'):
-        super(DQN, self).__init__()
+class Recurrent(nn.Module):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+        super().__init__()
+        self.state_shape = state_shape
+        self.action_shape = action_shape
         self.device = device
+        self.fc1 = nn.Linear(np.prod(state_shape), 128)
+        self.nn = nn.LSTM(input_size=128, hidden_size=128,
+                          num_layers=layer_num, batch_first=True)
+        self.fc2 = nn.Linear(128, np.prod(action_shape))
 
-        self.conv1 = nn.Conv2d(4, 16, kernel_size=5, stride=2)
-        self.bn1 = nn.BatchNorm2d(16)
-        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
-        self.bn2 = nn.BatchNorm2d(32)
-        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
-        self.bn3 = nn.BatchNorm2d(32)
-
-        def conv2d_size_out(size, kernel_size=5, stride=2):
-            return (size - (kernel_size - 1) - 1) // stride + 1
-
-        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
-        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
-        linear_input_size = convw * convh * 32
-        self.fc = nn.Linear(linear_input_size, 512)
-        self.head = nn.Linear(512, action_shape)
-
-    def forward(self, x, state=None, info={}):
-        if not isinstance(x, torch.Tensor):
-            x = torch.tensor(x, device=self.device, dtype=torch.float)
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = F.relu(self.bn2(self.conv2(x)))
-        x = F.relu(self.bn3(self.conv3(x)))
-        x = self.fc(x.reshape(x.size(0), -1))
-        return self.head(x), state
+    def forward(self, s, state=None, info={}):
+        if not isinstance(s, torch.Tensor):
+            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        # s is [bsz, len, dim] (training) or [bsz, dim] (evaluation),
+        # i.e. the training input carries an extra time dimension
+        # compared with the evaluation input.
+        if len(s.shape) == 2:
+            bsz, dim = s.shape
+            length = 1
+        else:
+            bsz, length, dim = s.shape
+        s = self.fc1(s.view([bsz * length, dim]))
+        s = s.view(bsz, length, -1)
+        self.nn.flatten_parameters()
+        if state is None:
+            s, (h, c) = self.nn(s)
+        else:
+            # we store the stacked data in [bsz, len, ...] format,
+            # but the pytorch rnn expects [len, bsz, ...]
+            s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(),
+                                    state['c'].transpose(0, 1).contiguous()))
+        s = self.fc2(s)[:, -1]
+        # please ensure the first dim is batch size: [bsz, len, ...]
+        return s, {'h': h.transpose(0, 1).detach(),
+                   'c': c.transpose(0, 1).detach()}
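
For orientation, here is a minimal usage sketch of the Recurrent module above. It is commentary only, not part of the patch; the shapes are illustrative, and the import path assumes the class is reachable the same way test_drqn.py (added below) imports it under pytest.

import numpy as np
from test.discrete.net import Recurrent  # same import path the test uses

net = Recurrent(layer_num=3, state_shape=(4,), action_shape=2)

# collection/evaluation: one observation per environment, no time axis
obs = np.zeros((8, 4), dtype=np.float32)   # [bsz, dim]
q, state = net(obs)                        # q: [8, 2] Q-values
print(state['h'].shape)                    # torch.Size([8, 3, 128]) = [bsz, num_layers, hidden]
q, state = net(obs, state=state)           # feed the returned hidden state back in at the next step

# training: frame-stacked observations sampled from the replay buffer
obs_stack = np.zeros((64, 4, 4), dtype=np.float32)  # [bsz, len, dim]
q, _ = net(obs_stack)                               # q: [64, 2]
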
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
new file mode 100644
index 0000000..76a6106
--- /dev/null
+++ b/test/discrete/test_drqn.py
@@ -0,0 +1,113 @@
+import gym
+import torch
+import pprint
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.env import VectorEnv
+from tianshou.policy import DQNPolicy
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer
+
+if __name__ == '__main__':
+    from net import Recurrent
+else:  # pytest
+    from test.discrete.net import Recurrent
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='CartPole-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--eps-test', type=float, default=0.05)
+    parser.add_argument('--eps-train', type=float, default=0.1)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--stack-num', type=int, default=4)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--n-step', type=int, default=3)
+    parser.add_argument('--target-update-freq', type=int, default=320)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=1000)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--layer-num', type=int, default=3)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_drqn(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
+    # you can also use tianshou.env.SubprocVectorEnv
+    train_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
+    # test_envs = gym.make(args.task)
+    test_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Recurrent(args.layer_num, args.state_shape,
+                    args.action_shape, args.device)
+    net = net.to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    policy = DQNPolicy(
+        net, optim, args.gamma, args.n_step,
+        use_target_network=args.target_update_freq > 0,
+        target_update_freq=args.target_update_freq)
+    # collector
+    train_collector = Collector(
+        policy, train_envs, ReplayBuffer(
+            args.buffer_size, stack_num=args.stack_num))
+    # the stack_num is for RNN training: sample frame-stacked obs
+    test_collector = Collector(policy, test_envs)
+    # policy.set_eps(1)
+    train_collector.collect(n_step=args.batch_size)
+    # log
+    writer = SummaryWriter(args.logdir + '/' + 'dqn')
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    def train_fn(x):
+        policy.set_eps(args.eps_train)
+
+    def test_fn(x):
+        policy.set_eps(args.eps_test)
+
+    # trainer
+    result = offpolicy_trainer(
+        policy, train_collector, test_collector, args.epoch,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.batch_size, train_fn=train_fn, test_fn=test_fn,
+        stop_fn=stop_fn, writer=writer)
+
+    assert stop_fn(result['best_reward'])
+    train_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
+
+
+if __name__ == '__main__':
+    test_drqn(get_args())
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index efda12e..91fc9c3 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -41,9 +41,10 @@ class ReplayBuffer(object):
         array([ True, True, True, True])
     """
 
-    def __init__(self, size):
+    def __init__(self, size, stack_num=0):
         super().__init__()
         self._maxsize = size
+        self._stack = stack_num
         self.reset()
 
     def __len__(self):
@@ -113,14 +114,28 @@
         ])
         return self[indice], indice
 
+    def _get_stack(self, indice, key):
+        if self.__dict__.get(key, None) is None:
+            return None
+        if self._stack == 0:
+            return self.__dict__[key][indice]
+        stack = []
+        for i in range(self._stack):
+            stack = [self.__dict__[key][indice]] + stack
+            indice = indice - 1 + self.done[indice - 1].astype(np.int)
+            indice[indice == -1] = self._size - 1
+        return np.stack(stack, axis=1)
+
     def __getitem__(self, index):
-        """Return a data batch: self[index]."""
+        """Return a data batch: self[index]. If stack_num is larger than 0,
+        return the stacked obs and obs_next with shape [batch, len, ...].
+        """
         return Batch(
-            obs=self.obs[index],
+            obs=self._get_stack(index, 'obs'),
             act=self.act[index],
             rew=self.rew[index],
             done=self.done[index],
-            obs_next=self.obs_next[index],
+            obs_next=self._get_stack(index, 'obs_next'),
             info=self.info[index]
         )
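
A toy sketch of what the new stack_num option does when sampling (commentary, not part of the patch; it exercises the ReplayBuffer and _get_stack defined above with made-up 1-D observations):

import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=20, stack_num=4)
for i in range(15):
    # 1-D toy observations; every 5th transition ends an episode
    buf.add(obs=np.array([i]), act=0, rew=0,
            done=(i % 5 == 4), obs_next=np.array([i + 1]))

batch, indice = buf.sample(batch_size=4)
# with stack_num=0, batch.obs would be [4, 1]; with stack_num=4 it becomes
# [4, 4, 1]: each sampled frame comes with its 3 predecessors, and frames are
# repeated rather than reaching back across an episode boundary
print(batch.obs.shape)  # (4, 4, 1)
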
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 3572d6c..ca1bd19 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -150,12 +150,30 @@ class Collector(object):
         self.env.close()
 
     def _make_batch(self, data):
+        """Return [data]."""
         if isinstance(data, np.ndarray):
             return data[None]
         else:
            return np.array([data])
 
-    def collect(self, n_step=0, n_episode=0, render=0):
+    def _reset_state(self, id):
+        """Reset self.state[id]."""
+        if self.state is None:
+            return
+        if isinstance(self.state, list):
+            self.state[id] = None
+        elif isinstance(self.state, dict):
+            for k in self.state:
+                if isinstance(self.state[k], list):
+                    self.state[k][id] = None
+                elif isinstance(self.state[k], torch.Tensor) or \
+                        isinstance(self.state[k], np.ndarray):
+                    self.state[k][id] = 0
+        elif isinstance(self.state, torch.Tensor) or \
+                isinstance(self.state, np.ndarray):
+            self.state[id] = 0
+
+    def collect(self, n_step=0, n_episode=0, render=None):
         """Collect a specified number of step or episode.
 
         :param int n_step: how many steps you want to collect.
@@ -163,7 +181,7 @@
             environment).
         :type n_episode: int or list
         :param float render: the sleep time between rendering consecutive
-            frames. No rendering if it is ``0`` (default option).
+            frames, defaults to ``None`` (no rendering).
 
         .. note::
 
@@ -218,9 +236,10 @@
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
-            if render > 0:
+            if render is not None:
                 self.env.render()
-                time.sleep(render)
+                if render > 0:
+                    time.sleep(render)
             self.length += 1
             self.reward += self._rew
             if self._multi_env:
@@ -253,16 +272,7 @@
                         self.reward[i], self.length[i] = 0, 0
                         if self._cached_buf:
                             self._cached_buf[i].reset()
-                        if isinstance(self.state, list):
-                            self.state[i] = None
-                        elif self.state is not None:
-                            if isinstance(self.state[i], dict):
-                                self.state[i] = {}
-                            else:
-                                self.state[i] = self.state[i] * 0
-                            if isinstance(self.state, torch.Tensor):
-                                # remove ref count in pytorch (?)
-                                self.state = self.state.detach()
+                        self._reset_state(i)
                 if sum(self._done):
                     obs_next = self.env.reset(np.where(self._done)[0])
                     if n_episode != 0:
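
For the dict-of-tensors hidden state returned by the Recurrent network above, the new _reset_state helper boils down to zeroing one batch row, so only the environment that just finished an episode starts from a fresh hidden state. A commentary sketch (not part of the patch, values illustrative):

import torch

# hidden state as returned by Recurrent: one row per parallel environment
state = {'h': torch.randn(8, 3, 128), 'c': torch.randn(8, 3, 128)}
i = 2  # environment i just reached done

# the branch _reset_state takes for a dict of tensors/arrays
for k in state:
    state[k][i] = 0

assert (state['h'][i] == 0).all()      # env i is reset
assert not (state['h'][0] == 0).all()  # the other envs keep their state
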
diff --git a/tianshou/policy/a2c.py b/tianshou/policy/a2c.py
index 52bd543..9ee807f 100644
--- a/tianshou/policy/a2c.py
+++ b/tianshou/policy/a2c.py
@@ -27,7 +27,7 @@ class A2CPolicy(PGPolicy):
                  dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, ent_coef=.01,
                  max_grad_norm=None, **kwargs):
-        super().__init__(None, optim, dist_fn, discount_factor)
+        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
         self.actor = actor
         self.critic = critic
         self._w_vf = vf_coef
diff --git a/tianshou/policy/ddpg.py b/tianshou/policy/ddpg.py
index a7a9578..b07f64d 100644
--- a/tianshou/policy/ddpg.py
+++ b/tianshou/policy/ddpg.py
@@ -34,7 +34,7 @@ class DDPGPolicy(BasePolicy):
                  tau=0.005, gamma=0.99, exploration_noise=0.1,
                  action_range=None, reward_normalization=False,
                  ignore_done=False, **kwargs):
-        super().__init__()
+        super().__init__(**kwargs)
         if actor is not None:
             self.actor, self.actor_old = actor, deepcopy(actor)
             self.actor_old.eval()
diff --git a/tianshou/policy/dqn.py b/tianshou/policy/dqn.py
index e8980a9..bd370bb 100644
--- a/tianshou/policy/dqn.py
+++ b/tianshou/policy/dqn.py
@@ -22,7 +22,7 @@ class DQNPolicy(BasePolicy):
     def __init__(self, model, optim, discount_factor=0.99,
                  estimation_step=1, target_update_freq=0,
                  **kwargs):
-        super().__init__()
+        super().__init__(**kwargs)
         self.model = model
         self.optim = optim
         self.eps = 0
diff --git a/tianshou/policy/pg.py b/tianshou/policy/pg.py
index ca31738..47b397d 100644
--- a/tianshou/policy/pg.py
+++ b/tianshou/policy/pg.py
@@ -17,7 +17,7 @@ class PGPolicy(BasePolicy):
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, **kwargs):
-        super().__init__()
+        super().__init__(**kwargs)
         self.model = model
         self.optim = optim
         self.dist_fn = dist_fn
diff --git a/tianshou/policy/ppo.py b/tianshou/policy/ppo.py
index 1bca452..a75eddb 100644
--- a/tianshou/policy/ppo.py
+++ b/tianshou/policy/ppo.py
@@ -36,7 +36,7 @@ class PPOPolicy(PGPolicy):
                  ent_coef=.0,
                  action_range=None,
                  **kwargs):
-        super().__init__(None, None, dist_fn, discount_factor)
+        super().__init__(None, None, dist_fn, discount_factor, **kwargs)
         self._max_grad_norm = max_grad_norm
         self._eps_clip = eps_clip
         self._w_vf = vf_coef
diff --git a/tianshou/policy/sac.py b/tianshou/policy/sac.py
index 7a28dc3..80cfe52 100644
--- a/tianshou/policy/sac.py
+++ b/tianshou/policy/sac.py
@@ -40,7 +40,8 @@ class SACPolicy(DDPGPolicy):
                  alpha=0.2, action_range=None, reward_normalization=False,
                  ignore_done=False, **kwargs):
         super().__init__(None, None, None, None, tau, gamma, 0,
-                         action_range, reward_normalization, ignore_done)
+                         action_range, reward_normalization, ignore_done,
+                         **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
diff --git a/tianshou/policy/td3.py b/tianshou/policy/td3.py
index 5df68f6..4d593d2 100644
--- a/tianshou/policy/td3.py
+++ b/tianshou/policy/td3.py
@@ -46,7 +46,7 @@ class TD3Policy(DDPGPolicy):
                  reward_normalization=False, ignore_done=False, **kwargs):
         super().__init__(actor, actor_optim, None, None, tau, gamma,
                          exploration_noise, action_range, reward_normalization,
-                         ignore_done)
+                         ignore_done, **kwargs)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
         self.critic1_optim = critic1_optim
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 78ea860..2cfa5c7 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -53,9 +53,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         policy.train()
         if train_fn:
             train_fn(epoch)
-        with tqdm.tqdm(
-                total=step_per_epoch, desc=f'Epoch #{epoch}',
-                **tqdm_config) as t:
+        with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
+                       **tqdm_config) as t:
             while t.n < t.total:
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index e633c5b..7f58431 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -58,9 +58,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
         policy.train()
         if train_fn:
             train_fn(epoch)
-        with tqdm.tqdm(
-                total=step_per_epoch, desc=f'Epoch #{epoch}',
-                **tqdm_config) as t:
+        with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
+                       **tqdm_config) as t:
             while t.n < t.total:
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}