diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 7cf30a2..eb0b796 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

-name: Python package
+name: Unittest

 on:
   push:
diff --git a/LICENSE b/LICENSE
index 0ccd14a..c94e299 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
 MIT License

-Copyright (c) 2019 n+e
+Copyright (c) 2020 TSAIL

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/setup.py b/setup.py
index fa5d74f..211008d 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,7 @@ setup(
         'gym',
         'tqdm',
         'numpy',
-        'torch',
+        'torch>=1.2.0', # for supporting tensorboard
         'cloudpickle',
         'tensorboard',
     ],
diff --git a/test/test_dqn.py b/test/test_dqn.py
index 98d6404..2774e62 100644
--- a/test/test_dqn.py
+++ b/test/test_dqn.py
@@ -14,7 +14,7 @@ from tianshou.data import Collector, ReplayBuffer


 class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape, device):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
         super().__init__()
         self.device = device
         self.model = [
@@ -93,8 +93,9 @@ def test_dqn(args=get_args()):
     writer = SummaryWriter(args.logdir)
     best_epoch = -1
     best_reward = -1e10
-    for epoch in range(args.epoch):
-        desc = f"Epoch #{epoch + 1}"
+    start_time = time.time()
+    for epoch in range(1, 1 + args.epoch):
+        desc = f"Epoch #{epoch}"
         # train
         policy.train()
         policy.sync_weight()
@@ -102,9 +103,9 @@ def test_dqn(args=get_args()):
         with tqdm.trange(
                 0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
             for _ in t:
-                training_collector.collect(n_step=args.collect_per_step)
+                result = training_collector.collect(
+                    n_step=args.collect_per_step)
                 global_step += 1
-                result = training_collector.stat()
                 loss = policy.learn(training_collector.sample(args.batch_size))
                 stat_loss.add(loss)
                 writer.add_scalar(
@@ -113,26 +114,34 @@ def test_dqn(args=get_args()):
                     'length', result['length'], global_step=global_step)
                 writer.add_scalar(
                     'loss', stat_loss.get(), global_step=global_step)
+                writer.add_scalar(
+                    'speed', result['speed'], global_step=global_step)
                 t.set_postfix(loss=f'{stat_loss.get():.6f}',
                               reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.6f}')
+                              length=f'{result["length"]:.6f}',
+                              speed=f'{result["speed"]:.2f}')
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
         policy.eval()
         policy.set_eps(args.eps_test)
-        test_collector.collect(n_episode=args.test_num)
-        result = test_collector.stat()
+        result = test_collector.collect(n_episode=args.test_num)
         if best_reward < result['reward']:
             best_reward = result['reward']
             best_epoch = epoch
-        print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
               f'best_reward: {best_reward:.6f} in #{best_epoch}')
         if args.task == 'CartPole-v0' and best_reward >= 200:
             break
     assert best_reward >= 200
     if __name__ == '__main__':
-        # let's watch its performance!
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
         env = gym.make(args.task)
         obs = env.reset()
         done = False
@@ -143,10 +152,9 @@ def test_dqn(args=get_args()):
             obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
             total += rew
             env.render()
-            time.sleep(1 / 100)
+            time.sleep(1 / 35)
         env.close()
         print(f'Final test: {total}')
-    return best_reward


 if __name__ == '__main__':
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index b19a05f..99d4489 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -15,6 +15,7 @@ class Collector(object):
         super().__init__()
         self.env = env
         self.env_num = 1
+        self.collect_step = 0
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
@@ -40,6 +41,7 @@ class Collector(object):
         self.state = None
         self.stat_reward = MovAvg(stat_size)
         self.stat_length = MovAvg(stat_size)
+        self.stat_speed = MovAvg(stat_size)

     def reset_buffer(self):
         if self._multi_buf:
@@ -78,6 +80,8 @@ class Collector(object):
         return [data]

     def collect(self, n_step=0, n_episode=0, render=0):
+        start_time = time.time()
+        start_step = self.collect_step
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
         cur_step = 0
@@ -119,9 +123,11 @@ class Collector(object):
                     elif self._multi_buf:
                         self.buffer[i].add(**data)
                         cur_step += 1
+                        self.collect_step += 1
                     else:
                         self.buffer.add(**data)
                         cur_step += 1
+                        self.collect_step += 1
                     if self._done[i]:
                         cur_episode[i] += 1
                         self.stat_reward.add(self.reward[i])
@@ -130,6 +136,7 @@ class Collector(object):
                         if self._cached_buf:
                             self.buffer.update(self._cached_buf[i])
                             cur_step += len(self._cached_buf[i])
+                            self.collect_step += len(self._cached_buf[i])
                             self._cached_buf[i].reset()
                         if isinstance(self.state, list):
                             self.state[i] = None
@@ -145,6 +152,7 @@ class Collector(object):
                     self._obs, self._act[0], self._rew, self._done, obs_next,
                     self._info)
                 cur_step += 1
+                self.collect_step += 1
                 if self._done:
                     cur_episode += 1
                     self.stat_reward.add(self.reward)
@@ -158,6 +166,13 @@ class Collector(object):
                 break
             self._obs = obs_next
         self._obs = obs_next
+        self.stat_speed.add((self.collect_step - start_step) / (
+            time.time() - start_time))
+        return {
+            'reward': self.stat_reward.get(),
+            'length': self.stat_length.get(),
+            'speed': self.stat_speed.get(),
+        }

     def sample(self, batch_size):
         if self._multi_buf:
@@ -179,9 +194,3 @@ class Collector(object):
             batch_data, indice = self.buffer.sample(batch_size)
             batch_data = self.process_fn(batch_data, self.buffer, indice)
         return batch_data
-
-    def stat(self):
-        return {
-            'reward': self.stat_reward.get(),
-            'length': self.stat_length.get(),
-        }
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 7a31160..680d803 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -10,7 +10,7 @@ class BasePolicy(ABC):

     @abstractmethod
     def __call__(self, batch, hidden_state=None):
-        # return Batch(policy, action, hidden)
+        # return Batch(act=np.array(), state=None, ...)
         pass

     @abstractmethod
@@ -22,6 +22,3 @@ class BasePolicy(ABC):

     def sync_weight(self):
         pass
-
-    def exploration(self):
-        pass
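
Usage note (not part of the patch): with this change, Collector.collect() returns the moving-average statistics directly and the separate Collector.stat() method is removed, so callers consume the result dict straight from collect(). A minimal sketch of the new call pattern, assuming a `policy` and a `training_collector` built as in test/test_dqn.py above:

    # Sketch only: `training_collector` and `policy` come from the test setup above.
    result = training_collector.collect(n_step=10)
    # collect() now returns the smoothed statistics itself; the removed stat()
    # call is no longer needed.
    print(result['reward'], result['length'], result['speed'])
    # collect_step accumulates across calls and feeds the final speed report.
    print(training_collector.collect_step)
    batch = training_collector.sample(64)  # process_fn is applied inside sample()
    loss = policy.learn(batch)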