some more API docs

haoshengzou 2018-04-15 11:46:46 +08:00
parent 9186dae6a3
commit 8c108174b6
5 changed files with 78 additions and 273 deletions


@@ -1,28 +0,0 @@
# TODO:
Separate actor and critic. (Important; we need to focus on this soon.)

# policy
YongRen

### base, stochastic
Follow OnehotCategorical to write Gaussian; it can live in the same file as stochastic.py.

### deterministic
Not sure how to write it yet, but it should at least have an act() method to interact with the environment.
Referencing QValuePolicy in base.py, it should have at least the listed methods (see the sketch after this TODO list).

# losses
TongzhengRen

These seem to be plain Python functions, though the management of placeholders may require some discussion;
they may also be written in a functional form.

# policy, value_function
Naming should be reconsidered. Perhaps use plural forms for all nouns.
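A purely hypothetical sketch of the deterministic-policy interface described in the TODO above. The class name, constructor arguments, and act() signature are all assumptions (the act(observation, my_feed_dict) convention is borrowed from the data-collection code later in this commit), not the real base class:

import tensorflow as tf


class DeterministicPolicy(object):
    """Hypothetical sketch: a deterministic policy only needs to map observations to actions."""
    def __init__(self, action_tensor, observation_placeholder):
        self._action = action_tensor  # deterministic action op of the actor network (assumed)
        self._observation_placeholder = observation_placeholder

    def act(self, observation, my_feed_dict={}):
        # Run the action op for a single observation; extra placeholders may be fed via my_feed_dict.
        sess = tf.get_default_session()
        feed_dict = {self._observation_placeholder: observation[None]}
        feed_dict.update(my_feed_dict)
        return sess.run(self._action, feed_dict=feed_dict)[0]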


@@ -7,11 +7,26 @@ import numpy as np


class RandomProcess(object):
    """
    Base class for random processes used for exploration in the environment.
    """
    def reset_states(self):
        """
        Reset the internal states, if any, of the random process. Does nothing by default.
        """
        pass


class AnnealedGaussianProcess(RandomProcess):
    """
    Class for an annealed Gaussian process, which anneals the sigma of the Gaussian-like distribution as
    sampling proceeds. At each timestep, the class samples from a Gaussian-like distribution.

    :param mu: A float. The mean of the Gaussian-like distribution.
    :param sigma: A float. The std of the Gaussian-like distribution.
    :param sigma_min: A float. The minimum std at which the annealing stops.
    :param n_steps_annealing: An int. The total number of steps over which the annealing happens.
    """
    def __init__(self, mu, sigma, sigma_min, n_steps_annealing):
        self.mu = mu
        self.sigma = sigma
@@ -28,11 +43,24 @@ class AnnealedGaussianProcess(RandomProcess):

    @property
    def current_sigma(self):
        """The current sigma after potential annealing."""
        sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c)
        return sigma


class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
    """
    Class for Gaussian white noise. At each timestep, the class samples from an exact Gaussian distribution.
    It allows annealing of the std of the Gaussian, but the distributions at different timesteps are independent.

    :param mu: A float defaulting to 0. The mean of the Gaussian distribution.
    :param sigma: A float defaulting to 1. The std of the Gaussian distribution.
    :param sigma_min: Optional. A float. The minimum std at which the annealing stops. Defaults to ``None``,
        in which case no annealing takes place.
    :param n_steps_annealing: Optional. An int. The total number of steps over which the annealing happens.
        Only effective when ``sigma_min`` is not ``None``.
    :param size: An int or tuple of ints. The shape of the action of the environment.
    """
    def __init__(self, mu=0., sigma=1., sigma_min=None, n_steps_annealing=1000, size=1):
        super(GaussianWhiteNoiseProcess, self).__init__(mu=mu, sigma=sigma, sigma_min=sigma_min,
                                                        n_steps_annealing=n_steps_annealing)
        self.size = size
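To make the annealing concrete, here is a standalone sketch of the schedule that current_sigma implements. It assumes __init__ sets the slope m = -(sigma - sigma_min) / n_steps_annealing and the intercept c = sigma, which is not shown in this hunk:

def annealed_sigma(n_steps, sigma=1.0, sigma_min=0.1, n_steps_annealing=1000):
    # Linear decay from sigma to sigma_min over n_steps_annealing steps, then constant.
    m = -(sigma - sigma_min) / float(n_steps_annealing)
    c = sigma
    return max(sigma_min, m * float(n_steps) + c)

# annealed_sigma(0) == 1.0, annealed_sigma(500) == 0.55, annealed_sigma(2000) == 0.1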
@@ -42,8 +70,27 @@ class GaussianWhiteNoiseProcess(AnnealedGaussianProcess):
        self.n_steps += 1
        return sample


# Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
    """
    Class for the Ornstein-Uhlenbeck process, as used for exploration in DDPG. Implemented based on
    http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab .
    It is essentially a temporally correlated Gaussian process in which the distribution at the current timestep
    depends on the sample from the last timestep. It is not exactly Gaussian, but it still resembles one.

    :param theta: A float. The rate at which the process reverts towards ``mu``.
    :param mu: A float. The value towards which the process reverts; it is not exactly the mean of the
        distribution at a given timestep.
    :param sigma: A float. It scales the noise added at each step and acts roughly like the std of the
        Gaussian-like distribution.
    :param dt: A float. The time interval used to simulate this process discretely, as the process is
        mathematically defined as a continuous one.
    :param x0: Optional. A float. The initial value of "the sample from the last timestep", used to draw the
        first sample. It defaults to zero.
    :param size: An int or tuple of ints. The shape of the action of the environment.
    :param sigma_min: Optional. A float. The minimum std at which the annealing stops. Defaults to ``None``,
        in which case no annealing takes place.
    :param n_steps_annealing: An int. The total number of steps over which the annealing happens.
    """
    def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000):
        super(OrnsteinUhlenbeckProcess, self).__init__(
            mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing)

@@ -55,7 +102,8 @@ class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
        self.reset_states()

    def sample(self):
        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + \
            self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
        self.x_prev = x
        self.n_steps += 1
        return x
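A small sketch contrasting the two processes, assuming this module has been imported; the parameter values are illustrative only:

import numpy as np

ou = OrnsteinUhlenbeckProcess(theta=0.15, mu=0., sigma=0.3, size=1)
white = GaussianWhiteNoiseProcess(mu=0., sigma=0.3, size=1)

ou_samples = np.array([ou.sample() for _ in range(1000)])
white_samples = np.array([white.sample() for _ in range(1000)])

# Consecutive OU samples are strongly correlated (with dt=1e-2 the process drifts slowly),
# while consecutive white-noise samples are essentially uncorrelated.
print(np.corrcoef(ou_samples[:-1, 0], ou_samples[1:, 0])[0, 1])
print(np.corrcoef(white_samples[:-1, 0], white_samples[1:, 0])[0, 1])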


@@ -3,10 +3,12 @@ import tensorflow as tf


def identify_dependent_variables(tensor, candidate_variables):
    """
    Identify and return the variables in ``candidate_variables`` that ``tensor`` depends on.

    :param tensor: A Tensor. The target Tensor whose dependencies are identified.
    :param candidate_variables: A list of :class:`tf.Variable` s. The candidate Variables to check for dependency.
    :return: A list of the :class:`tf.Variable` s in ``candidate_variables`` that have an effect on ``tensor``.
    """
    grads = tf.gradients(tensor, candidate_variables)
    return [var for var, grad in zip(candidate_variables, grads) if grad is not None]
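A quick usage sketch (TensorFlow 1.x graph mode; the variable names are illustrative): only the variables with a gradient path to the tensor are returned.

import tensorflow as tf

a = tf.Variable(1.0, name='a')
b = tf.Variable(2.0, name='b')
loss = 3.0 * a  # depends on a only

deps = identify_dependent_variables(loss, [a, b])
# tf.gradients(loss, [a, b]) returns [grad, None], so b is filtered out and deps == [a]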
@@ -14,10 +16,20 @@ def identify_dependent_variables(tensor, candidate_variables):


def get_soft_update_op(update_fraction, including_nets, excluding_nets=None):
    """
    Build the graph ops to softly update the "old net" of policies and value_functions, as suggested in
    `DDPG <https://arxiv.org/pdf/1509.02971.pdf>`_. It updates the :class:`tf.Variable` s in the old net,
    :math:`\\theta'`, with the :class:`tf.Variable` s in the current network, :math:`\\theta`, as
    :math:`\\theta' = \\tau \\theta + (1 - \\tau) \\theta'`.

    :param update_fraction: A float in the range :math:`(0, 1)`. It corresponds to the :math:`\\tau` in the
        update equation.
    :param including_nets: A list of policies and/or value_functions. All :class:`tf.Variable` s in these networks
        are included in the update. Shared Variables are updated only once in case of layer sharing among the
        networks.
    :param excluding_nets: Optional. A list of policies and/or value_functions, defaulting to ``None``.
        All :class:`tf.Variable` s in these networks are excluded from the update determined by
        ``including_nets``. This is useful when layers are shared among networks and only the Variables in
        ``including_nets`` that are not shared should be updated.
    :return: A list of :func:`tf.assign` ops specifying the soft update.
    """
    assert 0 < update_fraction < 1, 'Unrecommended update_fraction <=0 or >=1!'


@@ -1,235 +0,0 @@
import numpy as np
import gc
import logging

from . import utils


# TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner
class Batch(object):
    """
    Class for batch datasets. Collects multiple observations (actions, rewards, etc.) on-policy.
    """
    def __init__(self, env, pi, reward_processors, networks, render=False):  # how to name the function?
        """
        Constructor.

        :param env: The environment to interact with.
        :param pi: The policy used to act in ``env``.
        :param reward_processors: A list of functions to process rewards.
        :param networks: A list of networks to be optimized, so as to match the data in feed_dict.
        """
        self._env = env
        self._pi = pi
        self.raw_data = {}
        self.data = {}
        self.reward_processors = reward_processors
        self.networks = networks
        self.render = render

        self.required_placeholders = {}
        for net in self.networks:
            self.required_placeholders.update(net.managed_placeholders)
        self.require_advantage = 'advantage' in self.required_placeholders.keys()

        self._is_first_collect = True

    def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={},
                process_reward=True, epsilon_greedy=0):  # specify how many data to collect here, or fix it in __init__()
        assert sum(
            [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"

        if num_timesteps > 0:  # YouQiaoben: finish this implementation, the following code is just from openai/baselines
            t = 0
            ac = self._env.action_space.sample()  # not used, just so we have the datatype
            new = True  # marks if we're on the first timestep of an episode
            if self._is_first_collect:
                ob = self._env.reset()
                self._is_first_collect = False
            else:
                ob = self.raw_data['observations'][0]  # last observation!

            # Initialize history arrays
            observations = np.array([ob for _ in range(num_timesteps)])
            rewards = np.zeros(num_timesteps, 'float32')
            episode_start_flags = np.zeros(num_timesteps, 'int32')
            actions = np.array([ac for _ in range(num_timesteps)])

            for t in range(num_timesteps):
                pass

            while True:
                prevac = ac
                ac, vpred = pi.act(stochastic, ob)
                # Slight weirdness here because we need value function at time T
                # before returning segment [0, T-1] so we get the correct
                # terminal value
                i = t % horizon
                observations[i] = ob
                vpreds[i] = vpred
                episode_start_flags[i] = new
                actions[i] = ac
                prevacs[i] = prevac

                ob, rew, new, _ = self._env.step(ac)
                rewards[i] = rew

                cur_ep_ret += rew
                cur_ep_len += 1
                if new:
                    ep_rets.append(cur_ep_ret)
                    ep_lens.append(cur_ep_len)
                    cur_ep_ret = 0
                    cur_ep_len = 0
                    ob = self._env.reset()
                t += 1
        if num_episodes > 0:  # YouQiaoben: fix memory growth, both del and gc.collect() fail
            # initialize rawdata lists
            if not self._is_first_collect:
                del self.observations
                del self.actions
                del self.rewards
                del self.episode_start_flags

            observations = []
            actions = []
            rewards = []
            episode_start_flags = []

            # t_count = 0

            for _ in range(num_episodes):
                t_count = 0

                ob = self._env.reset()
                observations.append(ob)
                episode_start_flags.append(True)

                while True:
                    # a simple implementation of epsilon greedy
                    if epsilon_greedy > 0 and np.random.random() < epsilon_greedy:
                        ac = np.random.randint(low=0, high=self._env.action_space.n)
                    else:
                        ac = self._pi.act(ob, my_feed_dict)
                    actions.append(ac)

                    if self.render:
                        self._env.render()
                    ob, reward, done, _ = self._env.step(ac)
                    rewards.append(reward)

                    # t_count += 1
                    # if t_count >= 100:  # force episode stop, just to test if memory still grows
                    #     break

                    if done:  # end of episode, discard s_T
                        # TODO: for num_timesteps collection, has to store terminal flag instead of start flag!
                        break
                    else:
                        observations.append(ob)
                        episode_start_flags.append(False)

            self.observations = np.array(observations)
            self.actions = np.array(actions)
            self.rewards = np.array(rewards)
            self.episode_start_flags = np.array(episode_start_flags)

            del observations
            del actions
            del rewards
            del episode_start_flags

            self.raw_data = {'observation': self.observations, 'action': self.actions, 'reward': self.rewards,
                             'end_flag': self.episode_start_flags}

            self._is_first_collect = False

        if process_reward:
            self.apply_advantage_estimation_function()

        gc.collect()
    def apply_advantage_estimation_function(self):
        for processor in self.reward_processors:
            self.data.update(processor(self.raw_data))

    def next_batch(self, batch_size, standardize_advantage=True):
        rand_idx = np.random.choice(self.raw_data['observation'].shape[0], batch_size)

        # maybe re-compute advantage here, but only on rand_idx
        # but how to construct the feed_dict?
        if self.online:
            self.data_batch.update(self.apply_advantage_estimation_function(rand_idx))

        feed_dict = {}
        for key, placeholder in self.required_placeholders.items():
            feed_dict[placeholder] = utils.get_batch(self, key, rand_idx)
            found, data_key = utils.internal_key_match(key, self.raw_data.keys())
            if found:
                feed_dict[placeholder] = utils.get_batch(self.raw_data[data_key], rand_idx)  # self.raw_data[data_key][rand_idx]
            else:
                found, data_key = utils.internal_key_match(key, self.data.keys())
                if found:
                    feed_dict[placeholder] = self.data[data_key][rand_idx]

            if not found:
                raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))

        if standardize_advantage:
            if self.require_advantage:
                advantage_value = feed_dict[self.required_placeholders['advantage']]
                advantage_mean = np.mean(advantage_value)
                advantage_std = np.std(advantage_value)
                if advantage_std < 1e-3:
                    logging.warning('advantage_std too small (< 1e-3) for advantage standardization. '
                                    'may cause numerical issues')
                feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std

        # TODO: maybe move all advantage estimation functions to tf, as in tensorforce (though haven't
        # understood tensorforce after reading); maybe tf.stop_gradient for targets/advantages.
        # This will simplify the data collector as it only needs to collect raw data, (s, a, r, done) only.
        return feed_dict
    # TODO: this will definitely be refactored with a proper logger
    def statistics(self):
        """
        Compute and print the statistics of the currently sampled paths.
        """
        rewards = self.raw_data['reward']
        episode_start_flags = self.raw_data['end_flag']
        num_timesteps = rewards.shape[0]

        returns = []
        episode_lengths = []
        max_return = 0
        num_episodes = 1
        episode_start_idx = 0
        for i in range(1, num_timesteps):
            if episode_start_flags[i] or (
                    i == num_timesteps - 1):  # found the start of the next episode or the end of all episodes
                if episode_start_flags[i]:
                    num_episodes += 1
                if i < rewards.shape[0] - 1:
                    t = i - 1
                else:
                    t = i
                Gt = 0
                episode_lengths.append(t - episode_start_idx)
                while t >= episode_start_idx:
                    Gt += rewards[t]
                    t -= 1

                returns.append(Gt)
                if Gt > max_return:
                    max_return = Gt
                episode_start_idx = i

        print('AverageReturn: {}'.format(np.mean(returns)))
        print('StdReturn    : {}'.format(np.std(returns)))
        print('NumEpisodes  : {}'.format(num_episodes))
        print('MinMaxReturns: {}..., {}'.format(np.sort(returns)[:3], np.sort(returns)[-3:]))
        print('AverageLength: {}'.format(np.mean(episode_lengths)))
        print('MinMaxLengths: {}..., {}'.format(np.sort(episode_lengths)[:3], np.sort(episode_lengths)[-3:]))
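The return computation in statistics() can be summarized by this simplified standalone sketch (it ignores the original's special handling of the final timestep):

import numpy as np

rewards = np.array([1., 1., 1., 2., 2.], dtype='float32')
episode_start_flags = np.array([1, 0, 0, 1, 0], dtype='int32')

starts = np.flatnonzero(episode_start_flags)
bounds = np.append(starts, len(rewards))
# Sum undiscounted rewards between consecutive episode starts.
returns = [rewards[s:e].sum() for s, e in zip(bounds[:-1], bounds[1:])]
print(returns)  # [3.0, 4.0]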


@@ -6,9 +6,17 @@ from .data_buffer.replay_buffer_base import ReplayBufferBase
from .data_buffer.batch_set import BatchSet
from .utils import internal_key_match


class DataCollector(object):
    """
    A utility class to manage the data flow during the interaction between the policy and the environment.
    It stores data into ``data_buffer``, processes the reward signals and returns the feed_dict for TensorFlow
    graph running.

    :param env: The environment to interact with.
    :param policy: The policy that interacts with ``env``.
    :param data_buffer: The data buffer, e.g. a :class:`ReplayBufferBase` or a :class:`BatchSet`, in which the
        collected data are stored.
    :param process_functions: A list of functions that process the raw data, e.g. reward processing and
        advantage estimation.
    :param managed_networks: A list of networks whose placeholders are managed, so that the returned feed_dict
        matches the data they require.
    """
    def __init__(self, env, policy, data_buffer, process_functions, managed_networks):
        self.env = env
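A usage sketch of the intended data flow. The collect() and next_batch() method names and arguments below are assumptions carried over from the removed Batch class, and env, policy, data_buffer, process_functions, networks, train_op and sess stand for objects built elsewhere:

data_collector = DataCollector(env, policy, data_buffer, process_functions, networks)

for iteration in range(100):
    # interact with the environment and store the raw data into data_buffer
    data_collector.collect(num_episodes=50)
    # processed data keyed by the placeholders managed by the networks
    feed_dict = data_collector.next_batch(batch_size=64)
    sess.run(train_op, feed_dict=feed_dict)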