diff --git a/.gitignore b/.gitignore
index 6cc5860..4ceb487 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,3 +135,4 @@ dmypy.json
 
 # customize
 flake8.sh
+log/
diff --git a/test/test_dqn.py b/test/test_dqn.py
new file mode 100644
index 0000000..59bda18
--- /dev/null
+++ b/test/test_dqn.py
@@ -0,0 +1,136 @@
+import gym
+import tqdm
+import torch
+import argparse
+import numpy as np
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import DQNPolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils import tqdm_config, MovAvg
+from tianshou.data import Collector, ReplayBuffer
+
+
+class Net(nn.Module):
+    def __init__(self, layer_num, state_shape, action_shape, device):
+        super().__init__()
+        self.device = device
+        self.model = [
+            nn.Linear(np.prod(state_shape), 128),
+            nn.ReLU(inplace=True)]
+        for i in range(layer_num):
+            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+        self.model += [nn.Linear(128, np.prod(action_shape))]
+        self.model = nn.Sequential(*self.model)
+
+    def forward(self, s, **kwargs):
+        if not isinstance(s, torch.Tensor):
+            s = torch.Tensor(s)
+        s = s.to(self.device)
+        batch = s.shape[0]
+        q = self.model(s.view(batch, -1))
+        return q, None
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='CartPole-v0')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--eps-test', type=float, default=0.05)
+    parser.add_argument('--eps-train', type=float, default=0.1)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--n-step', type=int, default=1)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=320)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--layer-num', type=int, default=3)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=20)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_dqn(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    train_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)],
+        reset_after_done=True)
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)],
+        reset_after_done=False)
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
+    net = net.to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    loss = nn.MSELoss()
+    policy = DQNPolicy(net, optim, loss, args.gamma, args.n_step)
+    # collector
+    training_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size))
+    test_collector = Collector(
+        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    training_collector.collect(n_step=args.batch_size)
+    # log
+    stat_loss = MovAvg()
+    global_step = 0
+    writer = SummaryWriter(args.logdir)
+    best_epoch = -1
+    best_reward = -1e10
+    for epoch in range(args.epoch):
+        desc = f"Epoch #{epoch + 1}"
+        # train
+        policy.train()
+        policy.sync_weight()
+        policy.set_eps(args.eps_train)
+        with tqdm.trange(
+                0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            for _ in t:
+                training_collector.collect(n_step=args.collect_per_step)
+                global_step += 1
+                result = training_collector.stat()
+                loss = policy.learn(training_collector.sample(args.batch_size))
+                stat_loss.add(loss)
+                writer.add_scalar(
+                    'reward', result['reward'], global_step=global_step)
+                writer.add_scalar(
+                    'length', result['length'], global_step=global_step)
+                writer.add_scalar(
+                    'loss', stat_loss.get(), global_step=global_step)
+                t.set_postfix(loss=f'{stat_loss.get():.6f}',
+                              reward=f'{result["reward"]:.6f}',
+                              length=f'{result["length"]:.6f}')
+        # eval
+        test_collector.reset_env()
+        test_collector.reset_buffer()
+        policy.eval()
+        policy.set_eps(args.eps_test)
+        test_collector.collect(n_episode=args.test_num)
+        result = test_collector.stat()
+        if best_reward < result['reward']:
+            best_reward = result['reward']
+            best_epoch = epoch
+        print(f'Epoch #{epoch + 1} reward: {result["reward"]:.6f}, '
+              f'best_reward: {best_reward:.6f} in #{best_epoch}')
+        if args.task == 'CartPole-v0' and best_reward >= 200:
+            break
+    assert best_reward >= 200
+    return best_reward
+
+
+if __name__ == '__main__':
+    test_dqn(get_args())
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 80fb6f7..1f2a48f 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -65,6 +65,16 @@ class ReplayBuffer(object):
             info=self.info[indice]
         ), indice
 
+    def __getitem__(self, index):
+        return Batch(
+            obs=self.obs[index],
+            act=self.act[index],
+            rew=self.rew[index],
+            done=self.done[index],
+            obs_next=self.obs_next[index],
+            info=self.info[index]
+        )
+
 
 class PrioritizedReplayBuffer(ReplayBuffer):
     """docstring for PrioritizedReplayBuffer"""
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 1c7b448..d09da86 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -10,7 +10,7 @@ from tianshou.utils import MovAvg
 class Collector(object):
     """docstring for Collector"""
 
-    def __init__(self, policy, env, buffer, contiguous=True):
+    def __init__(self, policy, env, buffer, stat_size=100):
         super().__init__()
         self.env = env
         self.env_num = 1
@@ -19,27 +19,28 @@ class Collector(object):
         self.process_fn = policy.process_fn
         self._multi_env = isinstance(env, BaseVectorEnv)
         self._multi_buf = False  # buf is a list
-        # need multiple cache buffers only if contiguous in one buffer
+        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
         if self._multi_env:
             self.env_num = len(env)
             if isinstance(self.buffer, list):
                 assert len(self.buffer) == self.env_num,\
-                    '# of data buffer does not match the # of input env.'
+                    'The number of data buffer does not match the number of '\
+                    'input env.'
                 self._multi_buf = True
-            elif isinstance(self.buffer, ReplayBuffer) and contiguous:
+            elif isinstance(self.buffer, ReplayBuffer):
                 self._cached_buf = [
                     deepcopy(buffer) for _ in range(self.env_num)]
             else:
                 raise TypeError('The buffer in data collector is invalid!')
         self.reset_env()
-        self.clear_buffer()
-        # state over batch is either a list, an np.ndarray, or torch.Tensor
+        self.reset_buffer()
+        # state over batch is either a list, an np.ndarray, or a torch.Tensor
         self.state = None
-        self.stat_reward = MovAvg()
-        self.stat_length = MovAvg()
+        self.stat_reward = MovAvg(stat_size)
+        self.stat_length = MovAvg(stat_size)
 
-    def clear_buffer(self):
+    def reset_buffer(self):
         if self._multi_buf:
             for b in self.buffer:
                 b.reset()
@@ -57,6 +58,18 @@ class Collector(object):
         for b in self._cached_buf:
             b.reset()
 
+    def seed(self, seed=None):
+        if hasattr(self.env, 'seed'):
+            self.env.seed(seed)
+
+    def render(self):
+        if hasattr(self.env, 'render'):
+            self.env.render()
+
+    def close(self):
+        if hasattr(self.env, 'close'):
+            self.env.close()
+
     def _make_batch(data):
         if isinstance(data, np.ndarray):
             return data[None]
@@ -66,9 +79,10 @@
     def collect(self, n_step=0, n_episode=0):
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
-        cur_step, cur_episode = 0, 0
+        cur_step = 0
+        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
         while True:
-            if self.multi_env:
+            if self._multi_env:
                 batch_data = Batch(
                     obs=self._obs, act=self._act, rew=self._rew,
                     done=self._done, obs_next=None, info=self._info)
@@ -78,8 +92,9 @@
                     act=self._make_batch(self._act),
                     rew=self._make_batch(self._rew),
                     done=self._make_batch(self._done),
-                    obs_next=None, info=self._make_batch(self._info))
-            result = self.policy.act(batch_data, self.state)
+                    obs_next=None,
+                    info=self._make_batch(self._info))
+            result = self.policy(batch_data, self.state)
             self.state = result.state if hasattr(result, 'state') else None
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
@@ -88,6 +103,9 @@
             self.length += 1
             self.reward += self._rew
             if self._multi_env:
                 for i in range(self.env_num):
+                    if not self.env.is_reset_after_done()\
+                            and cur_episode[i] > 0:
+                        continue
                     data = {
                         'obs': self._obs[i], 'act': self._act[i],
                         'rew': self._rew[i], 'done': self._done[i],
@@ -101,7 +119,7 @@
                         self.buffer.add(**data)
                         cur_step += 1
                     if self._done[i]:
-                        cur_episode += 1
+                        cur_episode[i] += 1
                         self.stat_reward.add(self.reward[i])
                         self.stat_length.add(self.length[i])
                         self.reward[i], self.length[i] = 0, 0
@@ -111,12 +129,12 @@
                             self._cached_buf[i].reset()
                         if isinstance(self.state, list):
                             self.state[i] = None
-                        else:
+                        elif self.state is not None:
                             self.state[i] = self.state[i] * 0
                             if isinstance(self.state, torch.Tensor):
-                                # remove ref in torch (?)
+                                # remove ref count in pytorch (?)
                                 self.state = self.state.detach()
-                if n_episode > 0 and cur_episode >= n_episode:
+                if n_episode > 0 and cur_episode.sum() >= n_episode:
                     break
             else:
                 self.buffer.add(
@@ -141,13 +159,13 @@
             if batch_size > 0:
                 lens = [len(b) for b in self.buffer]
                 total = sum(lens)
-                ib = np.random.choice(
+                batch_index = np.random.choice(
                     total, batch_size, p=np.array(lens) / total)
             else:
-                ib = np.array([])
+                batch_index = np.array([])
             batch_data = Batch()
             for i, b in enumerate(self.buffer):
-                cur_batch = (ib == i).sum()
+                cur_batch = (batch_index == i).sum()
                 if batch_size and cur_batch or batch_size <= 0:
                     batch, indice = b.sample(cur_batch)
                     batch = self.process_fn(batch, b, indice)
diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py
index bef3aab..00aee18 100644
--- a/tianshou/env/wrapper.py
+++ b/tianshou/env/wrapper.py
@@ -1,6 +1,6 @@
 import numpy as np
-from abc import ABC
 from collections import deque
+from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
 try:
     import ray
@@ -63,6 +63,32 @@ class BaseVectorEnv(ABC):
         self.env_num = len(env_fns)
         self._reset_after_done = reset_after_done
 
+    def is_reset_after_done(self):
+        return self._reset_after_done
+
+    def __len__(self):
+        return self.env_num
+
+    @abstractmethod
+    def reset(self):
+        pass
+
+    @abstractmethod
+    def step(self, action):
+        pass
+
+    @abstractmethod
+    def seed(self, seed=None):
+        pass
+
+    @abstractmethod
+    def render(self):
+        pass
+
+    @abstractmethod
+    def close(self):
+        pass
+
 
 class VectorEnv(BaseVectorEnv):
     """docstring for VectorEnv"""
@@ -71,9 +97,6 @@
         super().__init__(env_fns, reset_after_done)
         self.envs = [_() for _ in env_fns]
 
-    def __len__(self):
-        return len(self.envs)
-
     def reset(self):
         return np.stack([e.reset() for e in self.envs])
 
@@ -148,9 +171,6 @@ class SubprocVectorEnv(BaseVectorEnv):
         for c in self.child_remote:
             c.close()
 
-    def __len__(self):
-        return self.env_num
-
     def step(self, action):
         assert len(action) == self.env_num
         for p, a in zip(self.parent_remote, action):
@@ -203,9 +223,6 @@ class RayVectorEnv(BaseVectorEnv):
             ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
             for e in env_fns]
 
-    def __len__(self):
-        return self.env_num
-
     def step(self, action):
         assert len(action) == self.env_num
         result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index c94c5e0..7a31160 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -6,25 +6,21 @@ class BasePolicy(ABC):
 
     def __init__(self):
         super().__init__()
+        self.model = None
 
     @abstractmethod
-    def act(self, batch, hidden_state=None):
+    def __call__(self, batch, hidden_state=None):
         # return Batch(policy, action, hidden)
         pass
 
-    def train(self):
-        pass
-
-    def eval(self):
-        pass
-
-    def reset(self):
+    @abstractmethod
+    def learn(self, batch):
         pass
 
     def process_fn(self, batch, buffer, indice):
         return batch
 
-    def sync_weights(self):
+    def sync_weight(self):
         pass
 
     def exploration(self):
diff --git a/tianshou/policy/dqn.py b/tianshou/policy/dqn.py
index 3179ba6..b3f41e3 100644
--- a/tianshou/policy/dqn.py
+++ b/tianshou/policy/dqn.py
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from torch import nn
 from copy import deepcopy
 
@@ -9,25 +10,86 @@ from tianshou.policy import BasePolicy
 class DQNPolicy(BasePolicy, nn.Module):
     """docstring for DQNPolicy"""
 
-    def __init__(self, model, discount_factor=0.99, estimation_step=1,
+    def __init__(self, model, optim, loss,
+                 discount_factor=0.99,
+                 estimation_step=1,
                  use_target_network=True):
         super().__init__()
         self.model = model
+        self.optim = optim
+        self.loss = loss
+        self.eps = 0
+        assert 0 <= discount_factor <= 1, 'discount_factor should be in [0, 1]'
+        self._gamma = discount_factor
+        assert estimation_step > 0, 'estimation_step should be greater than 0'
         self._n_step = estimation_step
         self._target = use_target_network
         if use_target_network:
             self.model_old = deepcopy(self.model)
+            self.model_old.eval()
 
-    def act(self, batch, hidden_state=None):
-        batch_result = Batch()
-        return batch_result
+    def __call__(self, batch, hidden_state=None,
+                 model='model', input='obs', eps=None):
+        model = getattr(self, model)
+        obs = getattr(batch, input)
+        q, h = model(obs, hidden_state=hidden_state, info=batch.info)
+        act = q.max(dim=1)[1].detach().cpu().numpy()
+        # add eps to act
+        for i in range(len(q)):
+            if np.random.rand() < self.eps:
+                act[i] = np.random.randint(q.shape[1])
+        return Batch(Q=q, act=act, state=h)
 
-    def sync_weights(self):
-        if self._use_target_network:
-            for old, new in zip(
-                    self.model_old.parameters(), self.model.parameters()):
-                old.data.copy_(new.data)
+    def set_eps(self, eps):
+        self.eps = eps
+
+    def train(self):
+        self.training = True
+        self.model.train()
+
+    def eval(self):
+        self.training = False
+        self.model.eval()
+
+    def sync_weight(self):
+        if self._target:
+            self.model_old.load_state_dict(self.model.state_dict())
 
     def process_fn(self, batch, buffer, indice):
+        returns = np.zeros_like(indice)
+        gammas = np.zeros_like(indice) + self._n_step
+        for n in range(self._n_step - 1, -1, -1):
+            now = (indice + n) % len(buffer)
+            gammas[buffer.done[now] > 0] = n
+            returns[buffer.done[now] > 0] = 0
+            returns = buffer.rew[now] + self._gamma * returns
+        terminal = (indice + self._n_step - 1) % len(buffer)
+        if self._target:
+            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+            a = self(buffer[terminal], input='obs_next').act
+            target_q = self(
+                buffer[terminal], model='model_old', input='obs_next').Q
+            if isinstance(target_q, torch.Tensor):
+                target_q = target_q.detach().cpu().numpy()
+            target_q = target_q[np.arange(len(a)), a]
+        else:
+            target_q = self(buffer[terminal], input='obs_next').Q
+            if isinstance(target_q, torch.Tensor):
+                target_q = target_q.detach().cpu().numpy()
+            target_q = target_q.max(axis=1)
+        target_q[gammas != self._n_step] = 0
+        returns += (self._gamma ** gammas) * target_q
+        batch.update(returns=returns)
         return batch
+
+    def learn(self, batch):
+        self.optim.zero_grad()
+        q = self(batch).Q
+        q = q[np.arange(len(q)), batch.act]
+        r = batch.returns
+        if isinstance(r, np.ndarray):
+            r = torch.tensor(r, device=q.device, dtype=q.dtype)
+        loss = self.loss(q, r)
+        loss.backward()
+        self.optim.step()
+        return loss.detach().cpu().numpy()
diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py
index b9e4250..783be29 100644
--- a/tianshou/utils/moving_average.py
+++ b/tianshou/utils/moving_average.py
@@ -21,3 +21,11 @@ class MovAvg(object):
         if len(self.cache) == 0:
             return 0
         return np.mean(self.cache)
+
+    def mean(self):
+        return self.get()
+
+    def std(self):
+        if len(self.cache) == 0:
+            return 0
+        return np.std(self.cache)