From 8c108174b6c3e8cb3bd395507f4ce052e724ee11 Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sun, 15 Apr 2018 11:46:46 +0800
Subject: [PATCH] some more API docs

---
 tianshou/core/README.md         | 28 ----
 tianshou/core/random.py         | 52 ++++++-
 tianshou/core/utils.py          | 26 +++-
 tianshou/data/batch.py          | 235 --------------------------------
 tianshou/data/data_collector.py | 10 +-
 5 files changed, 78 insertions(+), 273 deletions(-)
 delete mode 100644 tianshou/core/README.md
 delete mode 100644 tianshou/data/batch.py

diff --git a/tianshou/core/README.md b/tianshou/core/README.md
deleted file mode 100644
index a9cda58..0000000
--- a/tianshou/core/README.md
+++ /dev/null
@@ -1,28 +0,0 @@
-#TODO:
-
-Separate actor and critic. (Important, we need to focus on that recently)
-
-# policy
-
-YongRen
-
-### base, stochastic
-
-follow OnehotCategorical to write Gaussian, can be in the same file as stochastic.py
-
-### deterministic
-
-not sure how to write, but should at least have act() method to interact with environment
-
-referencing QValuePolicy in base.py, should have at least the listed methods.
-
-
-# losses
-
-TongzhengRen
-
-seems to be direct python functions. Though the management of placeholders may require some discussion. also may write it in a functional form.
-
-# policy, value_function
-
-naming should be reconsidered. Perhaps use plural forms for all nouns
\ No newline at end of file
diff --git a/tianshou/core/random.py b/tianshou/core/random.py
index fe6e5c7..4a00914 100644
--- a/tianshou/core/random.py
+++ b/tianshou/core/random.py
@@ -7,11 +7,26 @@ import numpy as np


 class RandomProcess(object):
+    """
+    Base class for random processes used for exploration in the environment.
+    """
     def reset_states(self):
+        """
+        Reset the internal states, if any, of the random process. Does nothing by default.
+        """
        pass


 class AnnealedGaussianProcess(RandomProcess):
+    """
+    Class for annealed Gaussian processes, which anneal the sigma of the Gaussian-like distribution as sampling proceeds.
+    At each timestep, the class samples from a Gaussian-like distribution.
+
+    :param mu: A float. Specifies the mean of the Gaussian-like distribution.
+    :param sigma: A float. Specifies the std of the Gaussian-like distribution.
+    :param sigma_min: A float. Specifies the minimum std at which the annealing stops.
+    :param n_steps_annealing: An int. Specifies the total number of steps over which the annealing happens.
+    """
     def __init__(self, mu, sigma, sigma_min, n_steps_annealing):
         self.mu = mu
         self.sigma = sigma
@@ -28,11 +43,24 @@ class AnnealedGaussianProcess(RandomProcess):

     @property
     def current_sigma(self):
+        """The current sigma after potential annealing."""
         sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c)
         return sigma


 class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
+    """
+    Class for Gaussian white noise. At each timestep, the class samples from an exact Gaussian distribution.
+    It allows annealing in the std of the Gaussian, but the distributions at different timesteps are independent.
+
+    :param mu: A float defaulting to 0. Specifies the mean of the Gaussian distribution.
+    :param sigma: A float defaulting to 1. Specifies the std of the Gaussian distribution.
+    :param sigma_min: Optional. A float. Specifies the minimum std at which the annealing stops. It defaults to
+        ``None``, in which case no annealing takes place.
+    :param n_steps_annealing: Optional. An int. Specifies the total number of steps over which the annealing happens.
+        Only effective when ``sigma_min`` is not ``None``.
+    :param size: An int or tuple of ints. It corresponds to the shape of the action of the environment.
+    """
     def __init__(self, mu=0., sigma=1., sigma_min=None, n_steps_annealing=1000, size=1):
         super(GaussianWhiteNoiseProcess, self).__init__(mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
         self.size = size
@@ -42,8 +70,27 @@ class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
         self.n_steps += 1
         return sample

-# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
+
 class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
+    """
+    Class for the Ornstein-Uhlenbeck process, as used for exploration in DDPG. Implemented based on
+    http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab .
+    It is essentially a temporally correlated Gaussian process where the distribution at the current timestep
+    depends on the sample from the last timestep. It is not exactly Gaussian but still resembles one.
+
+    :param theta: A float. The rate at which the process is pulled back towards ``mu`` at each step.
+    :param mu: A float. The value towards which the process reverts; it is not exactly the mean of the distribution.
+    :param sigma: A float. It scales the noise term and acts like the std of the Gaussian-like distribution
+        to some extent.
+    :param dt: A float. The time interval used to simulate this process discretely, as the process is mathematically
+        defined in continuous time.
+    :param x0: Optional. A float. The initial value of "the sample from the last timestep" used to draw the first
+        sample. It defaults to ``None``, in which case the process starts from zero.
+    :param size: An int or tuple of ints. It corresponds to the shape of the action of the environment.
+    :param sigma_min: Optional. A float. Specifies the minimum std at which the annealing stops. It defaults to
+        ``None``, in which case no annealing takes place.
+    :param n_steps_annealing: An int. Specifies the total number of steps over which the annealing happens.
+    """
     def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000):
         super(OrnsteinUhlenbeckProcess, self).__init__(
             mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)
@@ -55,7 +102,8 @@ class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
         self.reset_states()

     def sample(self):
-        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
+        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + \
+            self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
         self.x_prev = x
         self.n_steps += 1
         return x
diff --git a/tianshou/core/utils.py b/tianshou/core/utils.py
index 9a83f77..a8fcc9d 100644
--- a/tianshou/core/utils.py
+++ b/tianshou/core/utils.py
@@ -3,10 +3,12 @@ import tensorflow as tf

 def identify_dependent_variables(tensor, candidate_variables):
     """
-    identify the variables that `tensor` depends on
-    :param tensor: A Tensor.
-    :param candidate_variables: A list of Variables.
-    :return: A list of variables in `candidate variables` that has effect on `tensor`
+    Identify and return the variables in ``candidate_variables`` that ``tensor`` depends on.
+
+    :param tensor: A Tensor. The target Tensor whose dependencies are to be identified.
+    :param candidate_variables: A list of :class:`tf.Variable` s. The candidate Variables to check for dependency.
+
+    :return: A list of :class:`tf.Variable` s in ``candidate_variables`` that have an effect on ``tensor``.
""" grads = tf.gradients(tensor, candidate_variables) return [var for var, grad in zip(candidate_variables, grads) if grad is not None] @@ -14,10 +16,20 @@ def identify_dependent_variables(tensor, candidate_variables): def get_soft_update_op(update_fraction, including_nets, excluding_nets=None): """ + Builds the graph op to softly update the "old net" of policies and value_functions, as suggested in + `Link DDPG `_. It updates the :class:`tf.Variable` s in the old net, + :math:`\\theta'` with the :class:`tf.Variable` s in the current network, :math:`\\theta` as + :math:`\\theta' = \tau \\theta + (1 - \tau) \\theta'`. - :param including_nets: - :param excluding_nets: - :return: + :param update_fraction: A float in range :math:`[0, 1]`. Corresponding to the :math:`\tau` in the update equation. + :param including_nets: A list of policies and/or value_functions. All :class:`tf.Variable` s in these networks + are included for update. Shared Variables are updated only once in case of layer sharing among the networks. + :param excluding_nets: Optional. A list of policies and/or value_functions defaulting to ``None``. + All :class:`tf.Variable` s in these networks + are excluded from the update determined by ``including nets``. This is useful in existence of layer sharing + among networks and we only want to update the Variables in ``including_nets`` that are not shared. + + :return: A list of ops :func:`tf.assign` specifying the soft update. """ assert 0 < update_fraction < 1, 'Unrecommended update_fraction <=0 or >=1!' diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py deleted file mode 100644 index a1684ee..0000000 --- a/tianshou/data/batch.py +++ /dev/null @@ -1,235 +0,0 @@ -import numpy as np -import gc -import logging -from . import utils - -# TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner -class Batch(object): - """ - class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy. - """ - - def __init__(self, env, pi, reward_processors, networks, render=False): # how to name the function? - """ - constructor - :param env: - :param pi: - :param reward_processors: list of functions to process reward - :param networks: list of networks to be optimized, so as to match data in feed_dict - """ - self._env = env - self._pi = pi - self.raw_data = {} - self.data = {} - - self.reward_processors = reward_processors - self.networks = networks - self.render = render - - self.required_placeholders = {} - for net in self.networks: - self.required_placeholders.update(net.managed_placeholders) - self.require_advantage = 'advantage' in self.required_placeholders.keys() - - self._is_first_collect = True - - def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, - process_reward=True, epsilon_greedy=0): # specify how many data to collect here, or fix it in __init__() - assert sum( - [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!" - - if num_timesteps > 0: # YouQiaoben: finish this implementation, the following code are just from openai/baselines - t = 0 - ac = self._env.action_space.sample() # not used, just so we have the datatype - new = True # marks if we're on first timestep of an episode - if self.is_first_collect: - ob = self._env.reset() - self.is_first_collect = False - else: - ob = self.raw_data['observations'][0] # last observation! 
- - # Initialize history arrays - observations = np.array([ob for _ in range(num_timesteps)]) - rewards = np.zeros(num_timesteps, 'float32') - episode_start_flags = np.zeros(num_timesteps, 'int32') - actions = np.array([ac for _ in range(num_timesteps)]) - - for t in range(num_timesteps): - pass - - while True: - prevac = ac - ac, vpred = pi.act(stochastic, ob) - # Slight weirdness here because we need value function at time T - # before returning segment [0, T-1] so we get the correct - # terminal value - i = t % horizon - observations[i] = ob - vpreds[i] = vpred - episode_start_flags[i] = new - actions[i] = ac - prevacs[i] = prevac - - ob, rew, new, _ = self._env.step(ac) - rewards[i] = rew - - cur_ep_ret += rew - cur_ep_len += 1 - if new: - ep_rets.append(cur_ep_ret) - ep_lens.append(cur_ep_len) - cur_ep_ret = 0 - cur_ep_len = 0 - ob = self._env.reset() - t += 1 - - if num_episodes > 0: # YouQiaoben: fix memory growth, both del and gc.collect() fail - # initialize rawdata lists - if not self._is_first_collect: - del self.observations - del self.actions - del self.rewards - del self.episode_start_flags - - observations = [] - actions = [] - rewards = [] - episode_start_flags = [] - - # t_count = 0 - - for _ in range(num_episodes): - t_count = 0 - - ob = self._env.reset() - observations.append(ob) - episode_start_flags.append(True) - - while True: - # a simple implementation of epsilon greedy - if epsilon_greedy > 0 and np.random.random() < epsilon_greedy: - ac = np.random.randint(low = 0, high = self._env.action_space.n) - else: - ac = self._pi.act(ob, my_feed_dict) - actions.append(ac) - - if self.render: - self._env.render() - ob, reward, done, _ = self._env.step(ac) - rewards.append(reward) - - #t_count += 1 - #if t_count >= 100: # force episode stop, just to test if memory still grows - # break - - if done: # end of episode, discard s_T - # TODO: for num_timesteps collection, has to store terminal flag instead of start flag! - break - else: - observations.append(ob) - episode_start_flags.append(False) - - self.observations = np.array(observations) - self.actions = np.array(actions) - self.rewards = np.array(rewards) - self.episode_start_flags = np.array(episode_start_flags) - - del observations - del actions - del rewards - del episode_start_flags - - self.raw_data = {'observation': self.observations, 'action': self.actions, 'reward': self.rewards, - 'end_flag': self.episode_start_flags} - - self._is_first_collect = False - - if process_reward: - self.apply_advantage_estimation_function() - - gc.collect() - - def apply_advantage_estimation_function(self): - for processor in self.reward_processors: - self.data.update(processor(self.raw_data)) - - def next_batch(self, batch_size, standardize_advantage=True): - rand_idx = np.random.choice(self.raw_data['observation'].shape[0], batch_size) - - # maybe re-compute advantage here, but only on rand_idx - # but how to construct the feed_dict? 
- if self.online: - self.data_batch.update(self.apply_advantage_estimation_function(rand_idx)) - - - feed_dict = {} - for key, placeholder in self.required_placeholders.items(): - feed_dict[placeholder] = utils.get_batch(self, key, rand_idx) - - found, data_key = utils.internal_key_match(key, self.raw_data.keys()) - if found: - feed_dict[placeholder] = utils.get_batch(self.raw_data[data_key], rand_idx) # self.raw_data[data_key][rand_idx] - else: - found, data_key = utils.internal_key_match(key, self.data.keys()) - if found: - feed_dict[placeholder] = self.data[data_key][rand_idx] - - if not found: - raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name))) - - if standardize_advantage: - if self.require_advantage: - advantage_value = feed_dict[self.required_placeholders['advantage']] - advantage_mean = np.mean(advantage_value) - advantage_std = np.std(advantage_value) - if advantage_std < 1e-3: - logging.warning('advantage_std too small (< 1e-3) for advantage standardization. may cause numerical issues') - feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std - - # TODO: maybe move all advantage estimation functions to tf, as in tensorforce (though haven't - # understood tensorforce after reading) maybe tf.stop_gradient for targets/advantages - # this will simplify data collector as it only needs to collect raw data, (s, a, r, done) only - - return feed_dict - - # TODO: this will definitely be refactored with a proper logger - def statistics(self): - """ - compute the statistics of the current sampled paths - :return: - """ - rewards = self.raw_data['reward'] - episode_start_flags = self.raw_data['end_flag'] - num_timesteps = rewards.shape[0] - - returns = [] - episode_lengths = [] - max_return = 0 - num_episodes = 1 - episode_start_idx = 0 - for i in range(1, num_timesteps): - if episode_start_flags[i] or ( - i == num_timesteps - 1): # found the start of next episode or the end of all episodes - if episode_start_flags[i]: - num_episodes += 1 - if i < rewards.shape[0] - 1: - t = i - 1 - else: - t = i - Gt = 0 - episode_lengths.append(t - episode_start_idx) - while t >= episode_start_idx: - Gt += rewards[t] - t -= 1 - - returns.append(Gt) - if Gt > max_return: - max_return = Gt - episode_start_idx = i - - print('AverageReturn: {}'.format(np.mean(returns))) - print('StdReturn : {}'.format(np.std(returns))) - print('NumEpisodes : {}'.format(num_episodes)) - print('MinMaxReturns: {}..., {}'.format(np.sort(returns)[:3], np.sort(returns)[-3:])) - print('AverageLength: {}'.format(np.mean(episode_lengths))) - print('MinMaxLengths: {}..., {}'.format(np.sort(episode_lengths)[:3], np.sort(episode_lengths)[-3:])) diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py index ab3e3b4..6469e6a 100644 --- a/tianshou/data/data_collector.py +++ b/tianshou/data/data_collector.py @@ -6,9 +6,17 @@ from .data_buffer.replay_buffer_base import ReplayBufferBase from .data_buffer.batch_set import BatchSet from .utils import internal_key_match + class DataCollector(object): """ - a utility class to manage the interaction between buffer and advantage_estimation + A utility class to manage the data flow during the interaction between the policy and the environment. + It stores data into ``data_buffer``, processes the reward signals and returns the feed_dict for tf graph running. 
+
+    :param env: The environment to interact with.
+    :param policy: The policy that interacts with the environment.
+    :param data_buffer: The data buffer (e.g., a replay buffer or a batch set) in which to store the collected data.
+    :param process_functions: A list of functions to process the raw rewards, e.g., advantage estimation functions.
+    :param managed_networks: A list of networks (policies and/or value_functions) whose placeholders determine the data needed in the feed_dict.
     """
     def __init__(self, env, policy, data_buffer, process_functions, managed_networks):
         self.env = env
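
For reference, the following is a minimal usage sketch of the exploration-noise classes documented in this patch. It assumes the module shown in the diff is importable as ``tianshou.core.random``; the action dimension and the noise hyper-parameters are illustrative values only, not defaults of the library.

    import numpy as np

    from tianshou.core.random import GaussianWhiteNoiseProcess, OrnsteinUhlenbeckProcess

    action_dim = 3  # hypothetical shape of the environment's action

    # White noise whose std anneals from 1.0 down to 0.1 over 10000 samples.
    white_noise = GaussianWhiteNoiseProcess(mu=0., sigma=1., sigma_min=0.1,
                                            n_steps_annealing=10000, size=action_dim)

    # Temporally correlated Ornstein-Uhlenbeck noise, as commonly used for DDPG exploration.
    ou_noise = OrnsteinUhlenbeckProcess(theta=0.15, mu=0., sigma=0.3, dt=1e-2, size=action_dim)

    deterministic_action = np.zeros(action_dim)  # stand-in for the action computed by the policy
    for _ in range(5):
        # Each call to sample() returns an array of shape (action_dim,) and advances the internal step counter.
        exploring_action = deterministic_action + ou_noise.sample()
        # ... step the environment with exploring_action ...

    ou_noise.reset_states()       # reset the temporally correlated state at the start of a new episode
    print(white_noise.sample())   # independent Gaussian noise; no per-episode state to reset

Since the Ornstein-Uhlenbeck process carries state across calls (the sample from the last timestep), it should typically be reset at episode boundaries via ``reset_states()``, whereas the white-noise process samples independently at every timestep.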