diff --git a/.gitignore b/.gitignore
index b55c2b1..e795259 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,5 +6,4 @@ parameters
 *.sublime*
 checkpoints
 checkpoints_origin
-
-
+*.json
diff --git a/README.md b/README.md
index 674c4c7..d6d4ada 100644
--- a/README.md
+++ b/README.md
@@ -44,9 +44,14 @@ Tianshou(天授) is a reinforcement learning platform. The following image illus
 
 ## About coding style
 
-You can follow [google python coding style](https://google.github.io/styleguide/pyguide.html)
+Please follow the [Google Python coding style](https://google.github.io/styleguide/pyguide.html).
+
+Files should all be named with lower-case letters and underscores.
+
+Try to use full names. Don't use too many abbreviations for class/function/variable names, except common ones such as `num` for number, `dim` for dimension, `env` for environment, and `op` for operation. For now we use `pi` to refer to the policy in examples/ppo_example.py.
+
+The `"""xxx"""` docstring should be written right after the class/function definition. Also comment the parts of the code that are not intuitive. We must write comments, but for now we don't need to polish them.
 
-The file should all be named with lower case letters and underline.
 ## TODO
 
 Search based method parallel.
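To make the naming and docstring convention above concrete, here is a tiny illustrative snippet (the file and function names are hypothetical, not part of the codebase):

```python
# file: value_estimation.py  -- lower-case letters and underscores, full words

def estimate_state_value(observation, discount_factor=0.99):
    """
    Estimate the value of a single observation.

    :param observation: np.ndarray. A single observation from the environment.
    :param discount_factor: float. Discount applied to future rewards.
    """
    # comment the parts that are not intuitive; no need to polish for now
    raise NotImplementedError()
```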
diff --git a/examples/ppo_example.py b/examples/ppo_example.py
index 666fb5e..d6affbf 100755
--- a/examples/ppo_example.py
+++ b/examples/ppo_example.py
@@ -8,64 +8,64 @@ import gym
 import sys
 sys.path.append('..')
 
 import tianshou.core.losses as losses
-from tianshou.data.Batch import Batch
-import tianshou.data.adv_estimate as adv_estimate
+from tianshou.data.batch import Batch
+import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy as policy
 
 
-def policy_net(obs, act_dim, scope=None):
+def policy_net(observation, action_dim, scope=None):
     """
     Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
 
-    :param obs: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
-    :param act_dim: int. The number of actions.
+    :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
+    :param action_dim: int. The number of actions.
     :param scope: str. Specifying the scope of the variables.
     """
     # with tf.variable_scope(scope):
-    net = tf.layers.conv2d(obs, 16, 8, 4, 'valid', activation=tf.nn.relu)
+    net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)
    net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
     net = tf.layers.flatten(net)
     net = tf.layers.dense(net, 256, activation=tf.nn.relu)
 
-    act_logits = tf.layers.dense(net, act_dim)
+    act_logits = tf.layers.dense(net, action_dim)
 
     return act_logits
 
 
 if __name__ == '__main__': # a clean version with only policy net, no value net
     env = gym.make('PongNoFrameskip-v4')
-    obs_dim = env.observation_space.shape
-    act_dim = env.action_space.n
+    observation_dim = env.observation_space.shape
+    action_dim = env.action_space.n
 
     clip_param = 0.2
-    nb_batches = 2
+    num_batches = 2
 
     # 1. build network with pure tf
-    obs = tf.placeholder(tf.float32, shape=(None,) + obs_dim) # network input
+    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
 
     with tf.variable_scope('pi'):
-        act_logits = policy_net(obs, act_dim, 'pi')
+        action_logits = policy_net(observation, action_dim, 'pi')
         train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
     with tf.variable_scope('pi_old'):
-        act_logits_old = policy_net(obs, act_dim, 'pi_old')
+        action_logits_old = policy_net(observation, action_dim, 'pi_old')
 
     # 2. build losses, optimizers
-    pi = policy.OnehotCategorical(act_logits, obs_placeholder=obs) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
+    pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
     # for continuous action space, you may need to change an environment to run
-    pi_old = policy.OnehotCategorical(act_logits_old, obs_placeholder=obs)
+    pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation)
 
-    act = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions
-    Dgrad = tf.placeholder(dtype=tf.float32, shape=[None]) # values used in the Gradients
+    action = tf.placeholder(dtype=tf.int32, shape=[None]) # batch of integer actions
+    advantage = tf.placeholder(dtype=tf.float32, shape=[None]) # advantage values used in the Gradients
 
-    ppo_loss_clip = losses.ppo_clip(act, Dgrad, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
+    ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
 
     total_loss = ppo_loss_clip
     optimizer = tf.train.AdamOptimizer(1e-3)
     train_op = optimizer.minimize(total_loss, var_list=train_var_list)
 
     # 3. define data collection
-    training_data = Batch(env, pi, adv_estimate.full_return) # YouQiaoben: finish and polish Batch, adv_estimate.gae_lambda as in PPO paper
-    # ShihongSong: Replay(env, pi, adv_estimate.target_network), use your ReplayMemory, interact as follows. Simplify your adv_estimate.dqn to run before YongRen's DQN
+    training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
+    # ShihongSong: Replay(env, pi, advantage_estimation.target_network), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
     # maybe a dict to manage the elements to be collected
 
     # 4. start training
@@ -81,9 +81,9 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
             print('Collected {} times.'.format(collection_count))
 
             # update network
-            for _ in range(nb_batches):
+            for _ in range(num_batches):
                 data = training_data.next_batch(64) # YouQiaoben, ShihongSong
                 # TODO: auto managing of the placeholders? or add this to params of data.Batch
-                sess.run(train_op, feed_dict={obs: data['obs'], act: data['acs'], Dgrad: data['Gts']})
+                sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']})
                 minibatch_count += 1
                 print('Trained {} minibatches.'.format(minibatch_count))
\ No newline at end of file
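The hunks above do not show how `pi_old` is synchronized with `pi` before each round of updates, which the clipped objective relies on. For reference, one common pure-TF way to do it is sketched below; this is an assumption for illustration, not code from this commit:

```python
import tensorflow as tf

# Assumed helper: copy the current policy's weights into the old policy.
# Relies on the 'pi' and 'pi_old' variable scopes built in the example above
# and on both networks creating their variables in the same order.
pi_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='pi/')
pi_old_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='pi_old/')
sync_pi_old_op = tf.group(*[tf.assign(old, new) for old, new in zip(pi_old_vars, pi_vars)])

# Typical use inside the training loop, before the minibatch updates:
#     sess.run(sync_pi_old_op)
```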
diff --git a/tianshou/core/global_config.json b/tianshou/core/global_config.json
deleted file mode 100644
index 19d0ca3..0000000
--- a/tianshou/core/global_config.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
-  "global_description": "read by Environment, Neural Network, and MCTS",
-  "state_space": " ",
-  "action_space": " "
-}
diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py
index c38168f..5e127c2 100644
--- a/tianshou/core/losses.py
+++ b/tianshou/core/losses.py
@@ -1,29 +1,24 @@
 import tensorflow as tf
-import baselines.common.tf_util as U
 
 
-def ppo_clip(sampled_action, Dgrad, clip_param, pi, pi_old):
+def ppo_clip(sampled_action, advantage, clip_param, pi, pi_old):
+    """
+    The clip loss in the PPO paper.
+
+    :param sampled_action: placeholder of sampled actions during interaction with the environment.
+    :param advantage: placeholder of estimated advantage values.
+    :param clip_param: float or Tensor of type float.
+    :param pi: current `policy` to be optimized.
+    :param pi_old: old `policy` for computing the PPO loss as in Eqn. (7) of the paper.
+    """
+
     log_pi_act = pi.log_prob(sampled_action)
     log_pi_old_act = pi_old.log_prob(sampled_action)
     ratio = tf.exp(log_pi_act - log_pi_old_act)
     clipped_ratio = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param)
-    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * Dgrad, clipped_ratio * Dgrad))
+    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage, clipped_ratio * advantage))
     return ppo_clip_loss
 
 
-def L_VF(Gt, pi, St): # TODO: do we really have to specify St, or it's implicit in policy/value net
-    return U.mean(tf.square(pi.vpred - Gt))
-
-
-def entropy_reg(pi):
-    return - U.mean(pi.pd.entropy())
-
-
-def KL_diff(pi, pi_old):
-    kloldnew = pi_old.pd.kl(pi.pd)
-    meankl = U.mean(kloldnew)
-    return meankl
-
-
 def vanilla_policy_gradient():
     pass
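For reference, the quantity that `ppo_clip` computes is the clipped surrogate objective of the PPO paper (Schulman et al., 2017), Eqn. (7):

```latex
% Probability ratio between the current and the old policy (the `ratio` tensor):
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
            = \exp\!\left(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)\right)

% Clipped surrogate objective:
L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\;
    \mathrm{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right]
```

Here $\hat{A}_t$ corresponds to the `advantage` placeholder and $\epsilon$ to `clip_param`; `ppo_clip` returns the negation of the batch average of the `min` term, so minimizing the returned loss maximizes $L^{\mathrm{CLIP}}$.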
""" - def __init__(self, logits, obs_placeholder, dtype=None, group_ndims=0, **kwargs): + def __init__(self, logits, observation_placeholder, dtype=None, group_ndims=0, **kwargs): self._logits = tf.convert_to_tensor(logits) if dtype is None: @@ -53,7 +53,7 @@ class OnehotCategorical(StochasticPolicy): act_dtype=dtype, param_dtype=self._logits.dtype, is_continuous=False, - obs_placeholder=obs_placeholder, + observation_placeholder=observation_placeholder, group_ndims=group_ndims, **kwargs) @@ -69,7 +69,7 @@ class OnehotCategorical(StochasticPolicy): def _act(self, observation): sess = tf.get_default_session() # TODO: this may be ugly. also maybe huge problem when parallel - sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), feed_dict={self._obs_placeholder: observation[None]}) + sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), feed_dict={self._observation_placeholder: observation[None]}) sampled_action = sampled_action[0, 0] diff --git a/tianshou/core/policy_value.json b/tianshou/core/policy_value.json deleted file mode 100644 index e69de29..0000000 diff --git a/tianshou/core/state_mask.json b/tianshou/core/state_mask.json deleted file mode 100644 index 1d934fe..0000000 --- a/tianshou/core/state_mask.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "state" : "10", - "mask" : "1000" -} diff --git a/tianshou/data/Batch.py b/tianshou/data/Batch.py deleted file mode 100644 index 6b33c1b..0000000 --- a/tianshou/data/Batch.py +++ /dev/null @@ -1,128 +0,0 @@ -import numpy as np -import gc - - -# TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner -class Batch(object): - """ - class for batch datasets. Collect multiple states (actions, rewards, etc.) on-policy. - """ - - def __init__(self, env, pi, adv_estimation_func): # how to name the function? - self.env = env - self.pi = pi - self.adv_estimation_func = adv_estimation_func - self.is_first_collect = True - - - def collect(self, num_timesteps=0, num_episodes=0, apply_func=True): # specify how many data to collect here, or fix it in __init__() - assert sum([num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!" - - if num_timesteps > 0: # YouQiaoben: finish this implementation, the following code are just from openai/baselines - t = 0 - ac = self.env.action_space.sample() # not used, just so we have the datatype - new = True # marks if we're on first timestep of an episode - if self.is_first_collect: - ob = self.env.reset() - self.is_first_collect = False - else: - ob = self.raw_data['obs'][0] # last observation! 
diff --git a/tianshou/core/policy_value.json b/tianshou/core/policy_value.json
deleted file mode 100644
index e69de29..0000000
diff --git a/tianshou/core/state_mask.json b/tianshou/core/state_mask.json
deleted file mode 100644
index 1d934fe..0000000
--- a/tianshou/core/state_mask.json
+++ /dev/null
@@ -1,4 +0,0 @@
-{
-  "state" : "10",
-  "mask" : "1000"
-}
diff --git a/tianshou/data/Batch.py b/tianshou/data/Batch.py
deleted file mode 100644
index 6b33c1b..0000000
--- a/tianshou/data/Batch.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import numpy as np
-import gc
-
-
-# TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner
-class Batch(object):
-    """
-    class for batch datasets. Collect multiple states (actions, rewards, etc.) on-policy.
-    """
-
-    def __init__(self, env, pi, adv_estimation_func): # how to name the function?
-        self.env = env
-        self.pi = pi
-        self.adv_estimation_func = adv_estimation_func
-        self.is_first_collect = True
-
-
-    def collect(self, num_timesteps=0, num_episodes=0, apply_func=True): # specify how many data to collect here, or fix it in __init__()
-        assert sum([num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"
-
-        if num_timesteps > 0: # YouQiaoben: finish this implementation, the following code are just from openai/baselines
-            t = 0
-            ac = self.env.action_space.sample() # not used, just so we have the datatype
-            new = True # marks if we're on first timestep of an episode
-            if self.is_first_collect:
-                ob = self.env.reset()
-                self.is_first_collect = False
-            else:
-                ob = self.raw_data['obs'][0] # last observation!
-
-            # Initialize history arrays
-            obs = np.array([ob for _ in range(num_timesteps)])
-            rews = np.zeros(num_timesteps, 'float32')
-            news = np.zeros(num_timesteps, 'int32')
-            acs = np.array([ac for _ in range(num_timesteps)])
-
-            for t in range(num_timesteps):
-                pass
-
-            while True:
-                prevac = ac
-                ac, vpred = pi.act(stochastic, ob)
-                # Slight weirdness here because we need value function at time T
-                # before returning segment [0, T-1] so we get the correct
-                # terminal value
-                i = t % horizon
-                obs[i] = ob
-                vpreds[i] = vpred
-                news[i] = new
-                acs[i] = ac
-                prevacs[i] = prevac
-
-                ob, rew, new, _ = env.step(ac)
-                rews[i] = rew
-
-                cur_ep_ret += rew
-                cur_ep_len += 1
-                if new:
-                    ep_rets.append(cur_ep_ret)
-                    ep_lens.append(cur_ep_len)
-                    cur_ep_ret = 0
-                    cur_ep_len = 0
-                    ob = env.reset()
-                t += 1
-
-        if num_episodes > 0: # YouQiaoben: fix memory growth, both del and gc.collect() fail
-            # initialize rawdata lists
-            if not self.is_first_collect:
-                del self.obs
-                del self.acs
-                del self.rews
-                del self.news
-
-            obs = []
-            acs = []
-            rews = []
-            news = []
-
-            t_count = 0
-
-            for e in range(num_episodes):
-                ob = self.env.reset()
-                obs.append(ob)
-                news.append(True)
-
-                while True:
-                    ac = self.pi.act(ob)
-                    acs.append(ac)
-
-                    ob, rew, done, _ = self.env.step(ac)
-                    rews.append(rew)
-
-                    t_count += 1
-                    if t_count >= 200: # force episode stop
-                        break
-
-                    if done: # end of episode, discard s_T
-                        break
-                    else:
-                        obs.append(ob)
-                        news.append(False)
-
-            self.obs = np.array(obs)
-            self.acs = np.array(acs)
-            self.rews = np.array(rews)
-            self.news = np.array(news)
-
-            del obs
-            del acs
-            del rews
-            del news
-
-            self.raw_data = {'obs': self.obs, 'acs': self.acs, 'rews': self.rews, 'news': self.news}
-
-            self.is_first_collect = False
-
-        if apply_func:
-            self.apply_adv_estimation_func()
-
-        gc.collect()
-
-    def apply_adv_estimation_func(self):
-        self.data = self.adv_estimation_func(self.raw_data)
-
-    def next_batch(self, batch_size): # YouQiaoben: referencing other iterate over batches
-        rand_idx = np.random.choice(self.data['obs'].shape[0], batch_size)
-        return {key: value[rand_idx] for key, value in self.data.items()}
-
diff --git a/tianshou/data/adv_estimate.py b/tianshou/data/adv_estimate.py
deleted file mode 100644
index fa91351..0000000
--- a/tianshou/data/adv_estimate.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import numpy as np
-
-
-def full_return(raw_data):
-    """
-    naively compute full return
-    :param raw_data: dict of specified keys and values.
-    """
-    obs = raw_data['obs']
-    acs = raw_data['acs']
-    rews = raw_data['rews']
-    news = raw_data['news']
-    num_timesteps = rews.shape[0]
-
-    data = {}
-    data['obs'] = obs
-    data['acs'] = acs
-
-    Gts = rews.copy()
-    episode_start_idx = 0
-    for i in range(1, num_timesteps):
-        if news[i] or (i == num_timesteps - 1): # found one full episode
-            if i < rews.shape[0] - 1:
-                t = i - 1
-            else:
-                t = i
-            Gt = 0
-            while t >= episode_start_idx:
-                Gt += rews[t]
-                Gts[t] = Gt
-                t -= 1
-
-            episode_start_idx = i
-
-    data['Gts'] = Gts
-
-    return data
\ No newline at end of file
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
new file mode 100644
index 0000000..6f5b8a6
--- /dev/null
+++ b/tianshou/data/advantage_estimation.py
@@ -0,0 +1,37 @@
+import numpy as np
+
+
+def full_return(raw_data):
+    """
+    Naively compute the full (undiscounted) return of each timestep.
+    :param raw_data: dict of specified keys and values.
+    """
+    observations = raw_data['observations']
+    actions = raw_data['actions']
+    rewards = raw_data['rewards']
+    episode_start_flags = raw_data['episode_start_flags']
+    num_timesteps = rewards.shape[0]
+
+    data = {}
+    data['observations'] = observations
+    data['actions'] = actions
+
+    returns = rewards.copy()
+    episode_start_idx = 0
+    for i in range(1, num_timesteps):
+        if episode_start_flags[i] or (i == num_timesteps - 1): # found the start of next episode or the end of all episodes
+            if i < rewards.shape[0] - 1:
+                t = i - 1
+            else:
+                t = i
+            Gt = 0
+            while t >= episode_start_idx:
+                Gt += rewards[t]
+                returns[t] = Gt
+                t -= 1
+
+            episode_start_idx = i
+
+    data['returns'] = returns
+
+    return data
\ No newline at end of file
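As a quick sanity check of what `full_return` produces, here is a small sketch on a toy rollout of two episodes (the numbers are made up for illustration; the import assumes the repository root is on `sys.path`, as in the example script):

```python
import numpy as np

from tianshou.data.advantage_estimation import full_return

# Two episodes with rewards (1, 1, 1) and (2, 2). `episode_start_flags` marks
# where each episode begins. Observations/actions are placeholders here;
# full_return only passes them through unchanged.
raw_data = {
    'observations': np.zeros((5, 1), dtype='float32'),
    'actions': np.zeros(5, dtype='int32'),
    'rewards': np.array([1., 1., 1., 2., 2.], dtype='float32'),
    'episode_start_flags': np.array([1, 0, 0, 1, 0], dtype='int32'),
}

data = full_return(raw_data)
print(data['returns'])  # expected: [3. 2. 1. 4. 2.] -- undiscounted returns-to-go
```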
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
new file mode 100644
index 0000000..7b28966
--- /dev/null
+++ b/tianshou/data/batch.py
@@ -0,0 +1,128 @@
+import numpy as np
+import gc
+
+
+# TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner
+class Batch(object):
+    """
+    class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
+    """
+
+    def __init__(self, env, pi, advantage_estimation_function): # how to name the function?
+        self._env = env
+        self._pi = pi
+        self._advantage_estimation_function = advantage_estimation_function
+        self._is_first_collect = True
+
+
+    def collect(self, num_timesteps=0, num_episodes=0, apply_function=True): # specify how many data to collect here, or fix it in __init__()
+        assert sum([num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"
+
+        if num_timesteps > 0: # YouQiaoben: finish this implementation, the following code is just from openai/baselines
+            t = 0
+            ac = self.env.action_space.sample() # not used, just so we have the datatype
+            new = True # marks if we're on first timestep of an episode
+            if self.is_first_collect:
+                ob = self.env.reset()
+                self.is_first_collect = False
+            else:
+                ob = self.raw_data['observations'][0] # last observation!
+
+            # Initialize history arrays
+            observations = np.array([ob for _ in range(num_timesteps)])
+            rewards = np.zeros(num_timesteps, 'float32')
+            episode_start_flags = np.zeros(num_timesteps, 'int32')
+            actions = np.array([ac for _ in range(num_timesteps)])
+
+            for t in range(num_timesteps):
+                pass
+
+            while True:
+                prevac = ac
+                ac, vpred = pi.act(stochastic, ob)
+                # Slight weirdness here because we need value function at time T
+                # before returning segment [0, T-1] so we get the correct
+                # terminal value
+                i = t % horizon
+                observations[i] = ob
+                vpreds[i] = vpred
+                episode_start_flags[i] = new
+                actions[i] = ac
+                prevacs[i] = prevac
+
+                ob, rew, new, _ = env.step(ac)
+                rewards[i] = rew
+
+                cur_ep_ret += rew
+                cur_ep_len += 1
+                if new:
+                    ep_rets.append(cur_ep_ret)
+                    ep_lens.append(cur_ep_len)
+                    cur_ep_ret = 0
+                    cur_ep_len = 0
+                    ob = env.reset()
+                t += 1
+
+        if num_episodes > 0: # YouQiaoben: fix memory growth, both del and gc.collect() fail
+            # initialize rawdata lists
+            if not self._is_first_collect:
+                del self.observations
+                del self.actions
+                del self.rewards
+                del self.episode_start_flags
+
+            observations = []
+            actions = []
+            rewards = []
+            episode_start_flags = []
+
+            t_count = 0
+
+            for _ in range(num_episodes):
+                ob = self._env.reset()
+                observations.append(ob)
+                episode_start_flags.append(True)
+
+                while True:
+                    ac = self._pi.act(ob)
+                    actions.append(ac)
+
+                    ob, reward, done, _ = self._env.step(ac)
+                    rewards.append(reward)
+
+                    t_count += 1
+                    if t_count >= 200: # force episode stop, just to test if memory still grows
+                        break
+
+                    if done: # end of episode, discard s_T
+                        break
+                    else:
+                        observations.append(ob)
+                        episode_start_flags.append(False)
+
+            self.observations = np.array(observations)
+            self.actions = np.array(actions)
+            self.rewards = np.array(rewards)
+            self.episode_start_flags = np.array(episode_start_flags)
+
+            del observations
+            del actions
+            del rewards
+            del episode_start_flags
+
+            self.raw_data = {'observations': self.observations, 'actions': self.actions, 'rewards': self.rewards, 'episode_start_flags': self.episode_start_flags}
+
+            self._is_first_collect = False
+
+        if apply_function:
+            self.apply_advantage_estimation_function()
+
+        gc.collect()
+
+    def apply_advantage_estimation_function(self):
+        self.data = self._advantage_estimation_function(self.raw_data)
+
+    def next_batch(self, batch_size): # YouQiaoben: referencing other iterate over batches
+        rand_idx = np.random.choice(self.data['observations'].shape[0], batch_size)
+        return {key: value[rand_idx] for key, value in self.data.items()}
+
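Finally, a minimal end-to-end sketch of the data flow through the new `Batch` collector (the random policy stub and the `CartPole-v0` environment are my own choices for illustration; the PPO example above uses a real policy network instead):

```python
import gym
import numpy as np

from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation


class RandomPolicy(object):
    """Stub standing in for a real tianshou policy; Batch only needs `act`."""
    def __init__(self, env):
        self._action_space = env.action_space

    def act(self, observation):
        return self._action_space.sample()


env = gym.make('CartPole-v0')
pi = RandomPolicy(env)

training_data = Batch(env, pi, advantage_estimation.full_return)
training_data.collect(num_episodes=2)     # fills raw_data and applies full_return
minibatch = training_data.next_batch(16)  # random minibatch of the processed data
print(minibatch.keys())                   # 'observations', 'actions', 'returns'
```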