diff --git a/examples/dqn_replay.py b/examples/dqn_replay.py
index b9a5614..9fb8b4f 100644
--- a/examples/dqn_replay.py
+++ b/examples/dqn_replay.py
@@ -3,6 +3,8 @@ import tensorflow as tf
 import gym
 import numpy as np
 import time
+import logging
+logging.basicConfig(level=logging.INFO)
 
 # our lib imports here! It's ok to append path in examples
 import sys
@@ -12,8 +14,9 @@ import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.dqn as policy
 import tianshou.core.value_function.action_value as value_function
 
-from tianshou.data.replay_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
 from tianshou.data.data_collector import DataCollector
+from tianshou.data.tester import test_policy_in_env
 
 
 if __name__ == '__main__':
@@ -33,7 +36,7 @@ if __name__ == '__main__':
         return None, action_values # no policy head
 
     ### 2. build policy, loss, optimizer
-    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=200)
+    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=800)
     pi = policy.DQN(dqn)
 
     dqn_loss = losses.qlearning(dqn)
@@ -43,7 +46,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables)
 
     ### 3. define data collection
-    replay_buffer = VanillaReplayBuffer(capacity=1e5, nstep=1)
+    replay_buffer = VanillaReplayBuffer(capacity=2e4, nstep=1)
 
     process_functions = [advantage_estimation.nstep_q_return(1, dqn)]
     managed_networks = [dqn]
@@ -58,11 +61,11 @@ if __name__ == '__main__':
 
     ### 4. start training
     # hyper-parameters
-    batch_size = 256
+    batch_size = 128
     replay_buffer_warmup = 1000
-    epsilon_decay_interval = 200
-    epsilon = 0.3
-    test_interval = 1000
+    epsilon_decay_interval = 500
+    epsilon = 0.6
+    test_interval = 5000
 
     seed = 0
     np.random.seed(seed)
@@ -74,11 +77,11 @@ if __name__ == '__main__':
         sess.run(tf.global_variables_initializer())
 
         # assign actor to pi_old
-        pi.sync_weights() # TODO: automate this for policies with target network
+        pi.sync_weights() # TODO: rethink and redesign target network interface
 
         start_time = time.time()
         pi.set_epsilon_train(epsilon)
-        data_collector.collect(num_timesteps=replay_buffer_warmup) # warm-up
+        data_collector.collect(num_timesteps=replay_buffer_warmup) # TODO: uniform random warm-up
         for i in range(int(1e8)): # number of training steps
             # anneal epsilon step-wise
             if (i + 1) % epsilon_decay_interval == 0 and epsilon > 0.1:
@@ -86,7 +89,7 @@ if __name__ == '__main__':
                 pi.set_epsilon_train(epsilon)
 
             # collect data
-            data_collector.collect()
+            data_collector.collect(num_timesteps=4)
 
             # update network
             feed_dict = data_collector.next_batch(batch_size)
@@ -95,7 +98,7 @@ if __name__ == '__main__':
             # test every 1000 training steps
             # tester could share some code with batch!
             if i % test_interval == 0:
-                print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
+                print('Step {}, elapsed time: {:.1f} min'.format(i, (time.time() - start_time) / 60))
                 # epsilon 0.05 as in nature paper
                 pi.set_epsilon_test(0.05)
-                #test(env, pi) # go for act_test of pi, not act
+                test_policy_in_env(pi, env, num_timesteps=1000)
diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py
index 0054c9a..bd8ab72 100755
--- a/examples/ppo_cartpole.py
+++ b/examples/ppo_cartpole.py
@@ -11,15 +11,18 @@ import argparse
 import sys
 sys.path.append('..')
 from tianshou.core import losses
-from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy
+import tianshou.core.policy.stochastic as policy
+
+from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
+from tianshou.data.data_collector import DataCollector
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("--render", action="store_true", default=False)
     args = parser.parse_args()
+
     env = gym.make('CartPole-v0')
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.n
@@ -59,7 +62,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
     ### 3. define data collection
-    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
+    data_collector = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -73,15 +76,15 @@ if __name__ == '__main__':
         start_time = time.time()
         for i in range(100):
             # collect data
-            training_data.collect(num_episodes=50)
+            data_collector.collect(num_episodes=50)
 
             # print current return
             print('Epoch {}:'.format(i))
-            training_data.statistics()
+            data_collector.statistics()
 
             # update network
             for _ in range(num_batches):
-                feed_dict = training_data.next_batch(batch_size)
+                feed_dict = data_collector.next_batch(batch_size)
                 sess.run(train_op, feed_dict=feed_dict)
 
             # assigning actor to pi_old
diff --git a/tianshou/data/replay_buffer/__init__.py b/tianshou/data/data_buffer/__init__.py
similarity index 100%
rename from tianshou/data/replay_buffer/__init__.py
rename to tianshou/data/data_buffer/__init__.py
diff --git a/tianshou/data/replay_buffer/base.py b/tianshou/data/data_buffer/base.py
similarity index 54%
rename from tianshou/data/replay_buffer/base.py
rename to tianshou/data/data_buffer/base.py
index a779344..bf377d9 100644
--- a/tianshou/data/replay_buffer/base.py
+++ b/tianshou/data/data_buffer/base.py
@@ -1,13 +1,13 @@
-class ReplayBufferBase(object):
+class DataBufferBase(object):
     """
-    base class for replay buffer.
+    base class for data buffers, including the replay buffer as in DQN and the batched dataset as in on-policy algos
     """
     def add(self, frame):
         raise NotImplementedError()
 
-    def remove(self):
+    def clear(self):
         raise NotImplementedError()
 
     def sample(self, batch_size):
diff --git a/tianshou/data/data_buffer/batch_set.py b/tianshou/data/data_buffer/batch_set.py
new file mode 100644
index 0000000..57810f2
--- /dev/null
+++ b/tianshou/data/data_buffer/batch_set.py
@@ -0,0 +1,24 @@
+from .base import DataBufferBase
+
+
+class BatchSet(DataBufferBase):
+    """
+    class for the batched dataset as used in on-policy algos
+    """
+    def __init__(self):
+        self.data = [[]]
+        self.index = [[]]
+        self.candidate_index = 0
+
+        self.size = 0 # number of valid data points (not frames)
+
+        self.index_lengths = [0] # for sampling
+
+    def add(self, frame):
+        self.data[-1].append(frame)
+
+    def clear(self):
+        pass
+
+    def sample(self, batch_size):
+        pass
diff --git a/tianshou/data/data_buffer/replay_buffer_base.py b/tianshou/data/data_buffer/replay_buffer_base.py
new file mode 100644
index 0000000..dd437ae
--- /dev/null
+++ b/tianshou/data/data_buffer/replay_buffer_base.py
@@ -0,0 +1,12 @@
+from .base import DataBufferBase
+
+class ReplayBufferBase(DataBufferBase):
+    """
+    base class for replay buffer.
+    """
+    def remove(self):
+        """
+        when size exceeds capacity, removes extra data points
+        :return:
+        """
+        raise NotImplementedError()
diff --git a/tianshou/data/replay_buffer/vanilla.py b/tianshou/data/data_buffer/vanilla.py
similarity index 97%
rename from tianshou/data/replay_buffer/vanilla.py
rename to tianshou/data/data_buffer/vanilla.py
index 9996278..40d88dd 100644
--- a/tianshou/data/replay_buffer/vanilla.py
+++ b/tianshou/data/data_buffer/vanilla.py
@@ -1,13 +1,14 @@
 import logging
 import numpy as np
 
-from .base import ReplayBufferBase
+from .replay_buffer_base import ReplayBufferBase
 
 STATE = 0
 ACTION = 1
 REWARD = 2
 DONE = 3
 
+# TODO: valid data points could be fewer than `nstep` timesteps
 class VanillaReplayBuffer(ReplayBufferBase):
     """
     vanilla replay buffer as used in (Mnih, et al., 2015).
diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py
index 9834761..c3e1879 100644
--- a/tianshou/data/data_collector.py
+++ b/tianshou/data/data_collector.py
@@ -3,7 +3,8 @@ import logging
 import itertools
 import sys
 
-from .replay_buffer.base import ReplayBufferBase
+from .data_buffer.replay_buffer_base import ReplayBufferBase
+from .data_buffer.batch_set import BatchSet
 
 class DataCollector(object):
     """
@@ -31,10 +32,13 @@ class DataCollector(object):
 
         self.current_observation = self.env.reset()
 
-    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}):
+    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}, auto_clear=True):
         assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
             "One and only one collection number specification permitted!"
 
+        if isinstance(self.data_buffer, BatchSet) and auto_clear:
+            self.data_buffer.clear()
+
         if num_timesteps > 0:
             num_timesteps_ = int(num_timesteps)
             for _ in range(num_timesteps_):
diff --git a/tianshou/data/test_replay_buffer.py b/tianshou/data/test_replay_buffer.py
index 55c6ef2..b3ffd0c 100644
--- a/tianshou/data/test_replay_buffer.py
+++ b/tianshou/data/test_replay_buffer.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from replay_buffer.vanilla import VanillaReplayBuffer
+from data_buffer.vanilla import VanillaReplayBuffer
 
 capacity = 12
 nstep = 3
diff --git a/tianshou/data/tester.py b/tianshou/data/tester.py
index 2a2e407..7f55ab3 100644
--- a/tianshou/data/tester.py
+++ b/tianshou/data/tester.py
@@ -1,8 +1,75 @@
 from __future__ import absolute_import
+import gym
+import logging
+import numpy as np
+
+
+def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99):
+
+    assert sum([num_episodes > 0, num_timesteps > 0]) == 1, \
+        'One and only one collection number specification permitted!'
 
 
-def test_policy_in_env(policy, env):
     # make another env as the original is for training data collection
-    env_ = env
+    env_id = env.spec.id
+    env_ = gym.make(env_id)
 
-    pass
\ No newline at end of file
+    # test policy
+    returns = []
+    undiscounted_returns = []
+    current_return = 0.
+    current_undiscounted_return = 0.
+
+    if num_episodes > 0:
+        returns = [0.] * num_episodes
+        undiscounted_returns = [0.] * num_episodes
+        for i in range(num_episodes):
+            current_return = 0.
+            current_undiscounted_return = 0.
+            current_discount = 1.
+            observation = env_.reset()
+            done = False
+            while not done:
+                action = policy.act_test(observation)
+                observation, reward, done, _ = env_.step(action)
+                current_return += reward * current_discount
+                current_undiscounted_return += reward
+                current_discount *= discount_factor
+
+            returns[i] = current_return
+            undiscounted_returns[i] = current_undiscounted_return
+
+    # run for a fixed number of timesteps; only the first episode and finished episodes
+    # matter when calculating the average return
+    if num_timesteps > 0:
+        current_discount = 1.
+        observation = env_.reset()
+        for _ in range(num_timesteps):
+            action = policy.act_test(observation)
+            observation, reward, done, _ = env_.step(action)
+            current_return += reward * current_discount
+            current_undiscounted_return += reward
+            current_discount *= discount_factor
+
+            if done:
+                returns.append(current_return)
+                undiscounted_returns.append(current_undiscounted_return)
+                current_return = 0.
+                current_undiscounted_return = 0.
+                current_discount = 1.
+                observation = env_.reset()
+
+    # log
+    if returns: # has at least one finished episode
+        mean_return = np.mean(returns)
+        mean_undiscounted_return = np.mean(undiscounted_returns)
+    else: # the first episode is too long to finish
+        logging.warning('The first test episode is still not finished after {} timesteps. '
+                        'Logging its return anyway.'.format(num_timesteps))
+        mean_return = current_return
+        mean_undiscounted_return = current_undiscounted_return
+    logging.info('Mean return: {}'.format(mean_return))
+    logging.info('Mean undiscounted return: {}'.format(mean_undiscounted_return))
+
+    # clean up
+    env_.close()
\ No newline at end of file
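
Not part of the patch above, just a usage sketch of the new tester module: test_policy_in_env() only needs a gym env with a registered spec (it rebuilds a fresh copy from env.spec.id so the training env is left untouched) and a policy object exposing act_test(), and it reports mean discounted and undiscounted returns via logging. The RandomPolicy class below is a hypothetical stand-in that merely satisfies the act_test() interface; the examples pass a real tianshou policy instead, e.g. `pi` in examples/dqn_replay.py.

import logging
import gym

from tianshou.data.tester import test_policy_in_env

logging.basicConfig(level=logging.INFO)  # the tester reports through logging.info / logging.warning


class RandomPolicy(object):
    """Hypothetical stand-in policy; implements only the act_test() call used by the tester."""
    def __init__(self, action_space):
        self.action_space = action_space

    def act_test(self, observation):
        return self.action_space.sample()


env = gym.make('CartPole-v0')
pi = RandomPolicy(env.action_space)

# exactly one of num_episodes / num_timesteps must be given, as asserted in tester.py
test_policy_in_env(pi, env, num_episodes=10)     # average over 10 complete episodes
test_policy_in_env(pi, env, num_timesteps=1000)  # or over a fixed timestep budget; if no episode
                                                 # finishes, the partial return is logged with a warning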