fix naming and comments of coding style, delete .json
parent 0da31faa94, commit a00b930c2c

3  .gitignore (vendored)
@@ -6,5 +6,4 @@ parameters
*.sublime*
checkpoints
checkpoints_origin
-*.json
@@ -44,9 +44,14 @@ Tianshou(天授) is a reinforcement learning platform. The following image illus

## About coding style

-You can follow [google python coding style](https://google.github.io/styleguide/pyguide.html)
+Please follow [google python coding style](https://google.github.io/styleguide/pyguide.html)

+Files should all be named with lower case letters and underscores.
+
+Try to use full names. Don't use too many abbreviations for class/function/variable names, except for common ones (such as `num` for number, `dim` for dimension, `env` for environment, `op` for operation). For now we use `pi` to refer to the policy in examples/ppo_example.py.
+
+The """xxx""" docstring should be written right after the class/function definition. Also comment any part of the code that is not intuitive. We must comment, but for now we don't need to polish the comments.
+
-The file should all be named with lower case letters and underline.

## TODO

Search based method parallel.
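To make the conventions in the updated "About coding style" section concrete, here is a small illustrative sketch. It is not from the repository: `trajectory_buffer.py`, `TrajectoryBuffer`, and its members are made-up names used only to show the naming and docstring style.

```python
# trajectory_buffer.py  (hypothetical file name: lower case letters and underscores)
import numpy as np


class TrajectoryBuffer(object):
    """Stores observations, actions and rewards collected from one environment."""

    def __init__(self, num_timesteps, observation_dim):
        # full names for attributes; common abbreviations such as `num` and `dim` are fine
        self.observations = np.zeros((num_timesteps, observation_dim))
        self.actions = np.zeros(num_timesteps, dtype=np.int32)
        self.rewards = np.zeros(num_timesteps)

    def add(self, timestep, observation, action, reward):
        """Writes one transition at the given timestep."""
        self.observations[timestep] = observation
        self.actions[timestep] = action
        self.rewards[timestep] = reward
```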
@@ -8,64 +8,64 @@ import gym
import sys
sys.path.append('..')
import tianshou.core.losses as losses
-from tianshou.data.Batch import Batch
-import tianshou.data.adv_estimate as adv_estimate
+from tianshou.data.batch import Batch
+import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy as policy


-def policy_net(obs, act_dim, scope=None):
+def policy_net(observation, action_dim, scope=None):
    """
    Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf

-    :param obs: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
-    :param act_dim: int. The number of actions.
+    :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
+    :param action_dim: int. The number of actions.
    :param scope: str. Specifying the scope of the variables.
    """
    # with tf.variable_scope(scope):
-    net = tf.layers.conv2d(obs, 16, 8, 4, 'valid', activation=tf.nn.relu)
+    net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu)
    net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
    net = tf.layers.flatten(net)
    net = tf.layers.dense(net, 256, activation=tf.nn.relu)

-    act_logits = tf.layers.dense(net, act_dim)
+    act_logits = tf.layers.dense(net, action_dim)

    return act_logits


if __name__ == '__main__':  # a clean version with only policy net, no value net
    env = gym.make('PongNoFrameskip-v4')
-    obs_dim = env.observation_space.shape
-    act_dim = env.action_space.n
+    observation_dim = env.observation_space.shape
+    action_dim = env.action_space.n

    clip_param = 0.2
-    nb_batches = 2
+    num_batches = 2

    # 1. build network with pure tf
-    obs = tf.placeholder(tf.float32, shape=(None,) + obs_dim)  # network input
+    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim)  # network input

    with tf.variable_scope('pi'):
-        act_logits = policy_net(obs, act_dim, 'pi')
+        action_logits = policy_net(observation, action_dim, 'pi')
        train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)  # TODO: better management of TRAINABLE_VARIABLES
    with tf.variable_scope('pi_old'):
-        act_logits_old = policy_net(obs, act_dim, 'pi_old')
+        action_logits_old = policy_net(observation, action_dim, 'pi_old')

    # 2. build losses, optimizers
-    pi = policy.OnehotCategorical(act_logits, obs_placeholder=obs)  # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
+    pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation)  # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
    # for continuous action space, you may need to change an environment to run
-    pi_old = policy.OnehotCategorical(act_logits_old, obs_placeholder=obs)
+    pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation)

-    act = tf.placeholder(dtype=tf.int32, shape=[None])  # batch of integer actions
-    Dgrad = tf.placeholder(dtype=tf.float32, shape=[None])  # values used in the Gradients
+    action = tf.placeholder(dtype=tf.int32, shape=[None])  # batch of integer actions
+    advantage = tf.placeholder(dtype=tf.float32, shape=[None])  # advantage values used in the Gradients

-    ppo_loss_clip = losses.ppo_clip(act, Dgrad, clip_param, pi, pi_old)  # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
+    ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old)  # TongzhengRen: losses.vpg ... management of placeholders and feed_dict

    total_loss = ppo_loss_clip
    optimizer = tf.train.AdamOptimizer(1e-3)
    train_op = optimizer.minimize(total_loss, var_list=train_var_list)

    # 3. define data collection
-    training_data = Batch(env, pi, adv_estimate.full_return)  # YouQiaoben: finish and polish Batch, adv_estimate.gae_lambda as in PPO paper
-    # ShihongSong: Replay(env, pi, adv_estimate.target_network), use your ReplayMemory, interact as follows. Simplify your adv_estimate.dqn to run before YongRen's DQN
+    training_data = Batch(env, pi, advantage_estimation.full_return)  # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
+    # ShihongSong: Replay(env, pi, advantage_estimation.target_network), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
    # maybe a dict to manage the elements to be collected

    # 4. start training
@@ -81,9 +81,9 @@ if __name__ == '__main__':  # a clean version with only policy net, no value net
        print('Collected {} times.'.format(collection_count))

        # update network
-        for _ in range(nb_batches):
+        for _ in range(num_batches):
            data = training_data.next_batch(64)  # YouQiaoben, ShihongSong
            # TODO: auto managing of the placeholders? or add this to params of data.Batch
-            sess.run(train_op, feed_dict={obs: data['obs'], act: data['acs'], Dgrad: data['Gts']})
+            sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']})
            minibatch_count += 1
        print('Trained {} minibatches.'.format(minibatch_count))
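As a quick sanity check of the renamed `policy_net` interface, here is a minimal sketch. It assumes TensorFlow 1.x and the Atari gym environment used above, and that `policy_net` from examples/ppo_example.py is in scope; it is not part of the diff.

```python
import gym
import tensorflow as tf

# assumes policy_net from the example above has been defined or imported here
env = gym.make('PongNoFrameskip-v4')
observation_dim = env.observation_space.shape   # (210, 160, 3): raw Atari frames
action_dim = env.action_space.n                 # 6 discrete actions for Pong

observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
with tf.variable_scope('pi'):
    action_logits = policy_net(observation, action_dim, 'pi')

print(action_logits.get_shape().as_list())      # [None, 6]: one logit per discrete action
```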
@@ -1,5 +0,0 @@
{
    "global_description": "read by Environment, Neural Network, and MCTS",
    "state_space": " ",
    "action_space": " "
}
@@ -1,29 +1,24 @@
import tensorflow as tf
import baselines.common.tf_util as U


-def ppo_clip(sampled_action, Dgrad, clip_param, pi, pi_old):
+def ppo_clip(sampled_action, advantage, clip_param, pi, pi_old):
    """
    the clip loss in ppo paper

    :param sampled_action: placeholder of sampled actions during interaction with the environment
+    :param advantage: placeholder of estimated advantage values.
    :param clip_param: float or Tensor of type float.
    :param pi: current `policy` to be optimized
    :param pi_old: old `policy` for computing the ppo loss as in Eqn. (7) in the paper
    """

    log_pi_act = pi.log_prob(sampled_action)
    log_pi_old_act = pi_old.log_prob(sampled_action)
    ratio = tf.exp(log_pi_act - log_pi_old_act)
    clipped_ratio = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param)
-    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * Dgrad, clipped_ratio * Dgrad))
+    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage, clipped_ratio * advantage))
    return ppo_clip_loss


def L_VF(Gt, pi, St):  # TODO: do we really have to specify St, or it's implicit in policy/value net
    return U.mean(tf.square(pi.vpred - Gt))


def entropy_reg(pi):
    return - U.mean(pi.pd.entropy())


def KL_diff(pi, pi_old):
    kloldnew = pi_old.pd.kl(pi.pd)
    meankl = U.mean(kloldnew)
    return meankl


def vanilla_policy_gradient():
    pass
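For reference, `ppo_clip` above returns the negative of the clipped surrogate objective from the PPO paper (its Eq. (7)), so that minimizing the loss maximizes the objective. In the code, `ratio` is $r_t(\theta)$, the `advantage` placeholder supplies $\hat{A}_t$, and `clip_param` is $\epsilon$:

$$
L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\!\left[\min\!\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}.
$$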
@@ -83,14 +83,14 @@ class StochasticPolicy(object):
                 act_dtype,
                 param_dtype,
                 is_continuous,
-                 obs_placeholder,
+                 observation_placeholder,
                 group_ndims=0,  # maybe useful for repeat_action
                 **kwargs):

        self._act_dtype = act_dtype
        self._param_dtype = param_dtype
        self._is_continuous = is_continuous
-        self._obs_placeholder = obs_placeholder
+        self._observation_placeholder = observation_placeholder
        if isinstance(group_ndims, int):
            if group_ndims < 0:
                raise ValueError("group_ndims must be non-negative.")
@@ -39,7 +39,7 @@ class OnehotCategorical(StochasticPolicy):
    `[i, j, ..., k, :]` is a one-hot vector of the selected category.
    """

-    def __init__(self, logits, obs_placeholder, dtype=None, group_ndims=0, **kwargs):
+    def __init__(self, logits, observation_placeholder, dtype=None, group_ndims=0, **kwargs):
        self._logits = tf.convert_to_tensor(logits)

        if dtype is None:
@@ -53,7 +53,7 @@ class OnehotCategorical(StochasticPolicy):
            act_dtype=dtype,
            param_dtype=self._logits.dtype,
            is_continuous=False,
-            obs_placeholder=obs_placeholder,
+            observation_placeholder=observation_placeholder,
            group_ndims=group_ndims,
            **kwargs)

@@ -69,7 +69,7 @@ class OnehotCategorical(StochasticPolicy):

    def _act(self, observation):
        sess = tf.get_default_session()  # TODO: this may be ugly. also maybe huge problem when parallel
-        sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), feed_dict={self._obs_placeholder: observation[None]})
+        sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), feed_dict={self._observation_placeholder: observation[None]})

        sampled_action = sampled_action[0, 0]

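A minimal sketch of sampling through this `_act` path. It is an assumption-laden illustration: it presumes TensorFlow 1.x and that the public `act()` method dispatches to `_act()`, which is how examples/ppo_example.py and batch.py call it; the two-unit dense network is an arbitrary stand-in, not part of the library.

```python
import numpy as np
import tensorflow as tf
import tianshou.core.policy as policy

observation = tf.placeholder(tf.float32, shape=(None, 4))
logits = tf.layers.dense(observation, 2)                  # arbitrary stand-in network
pi = policy.OnehotCategorical(logits, observation_placeholder=observation)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # _act feeds observation[None] into the placeholder and samples via tf.multinomial
    sampled_action = pi.act(np.zeros(4, dtype=np.float32))
    print(sampled_action)                                  # an integer index into the two categories
```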
@@ -1,4 +0,0 @@
{
    "state" : "10",
    "mask" : "1000"
}
tianshou/data/Batch.py (deleted file)

@@ -1,128 +0,0 @@
import numpy as np
import gc


# TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner
class Batch(object):
    """
    class for batch datasets. Collect multiple states (actions, rewards, etc.) on-policy.
    """

    def __init__(self, env, pi, adv_estimation_func):  # how to name the function?
        self.env = env
        self.pi = pi
        self.adv_estimation_func = adv_estimation_func
        self.is_first_collect = True

    def collect(self, num_timesteps=0, num_episodes=0, apply_func=True):  # specify how many data to collect here, or fix it in __init__()
        assert sum([num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"

        if num_timesteps > 0:  # YouQiaoben: finish this implementation, the following code are just from openai/baselines
            t = 0
            ac = self.env.action_space.sample()  # not used, just so we have the datatype
            new = True  # marks if we're on first timestep of an episode
            if self.is_first_collect:
                ob = self.env.reset()
                self.is_first_collect = False
            else:
                ob = self.raw_data['obs'][0]  # last observation!

            # Initialize history arrays
            obs = np.array([ob for _ in range(num_timesteps)])
            rews = np.zeros(num_timesteps, 'float32')
            news = np.zeros(num_timesteps, 'int32')
            acs = np.array([ac for _ in range(num_timesteps)])

            for t in range(num_timesteps):
                pass

            while True:
                prevac = ac
                ac, vpred = pi.act(stochastic, ob)
                # Slight weirdness here because we need value function at time T
                # before returning segment [0, T-1] so we get the correct
                # terminal value
                i = t % horizon
                obs[i] = ob
                vpreds[i] = vpred
                news[i] = new
                acs[i] = ac
                prevacs[i] = prevac

                ob, rew, new, _ = env.step(ac)
                rews[i] = rew

                cur_ep_ret += rew
                cur_ep_len += 1
                if new:
                    ep_rets.append(cur_ep_ret)
                    ep_lens.append(cur_ep_len)
                    cur_ep_ret = 0
                    cur_ep_len = 0
                    ob = env.reset()
                t += 1

        if num_episodes > 0:  # YouQiaoben: fix memory growth, both del and gc.collect() fail
            # initialize rawdata lists
            if not self.is_first_collect:
                del self.obs
                del self.acs
                del self.rews
                del self.news

            obs = []
            acs = []
            rews = []
            news = []

            t_count = 0

            for e in range(num_episodes):
                ob = self.env.reset()
                obs.append(ob)
                news.append(True)

                while True:
                    ac = self.pi.act(ob)
                    acs.append(ac)

                    ob, rew, done, _ = self.env.step(ac)
                    rews.append(rew)

                    t_count += 1
                    if t_count >= 200:  # force episode stop
                        break

                    if done:  # end of episode, discard s_T
                        break
                    else:
                        obs.append(ob)
                        news.append(False)

            self.obs = np.array(obs)
            self.acs = np.array(acs)
            self.rews = np.array(rews)
            self.news = np.array(news)

            del obs
            del acs
            del rews
            del news

            self.raw_data = {'obs': self.obs, 'acs': self.acs, 'rews': self.rews, 'news': self.news}

            self.is_first_collect = False

            if apply_func:
                self.apply_adv_estimation_func()

        gc.collect()

    def apply_adv_estimation_func(self):
        self.data = self.adv_estimation_func(self.raw_data)

    def next_batch(self, batch_size):  # YouQiaoben: referencing other iterate over batches
        rand_idx = np.random.choice(self.data['obs'].shape[0], batch_size)
        return {key: value[rand_idx] for key, value in self.data.items()}
tianshou/data/adv_estimate.py (deleted file)

@@ -1,37 +0,0 @@
import numpy as np


def full_return(raw_data):
    """
    naively compute full return
    :param raw_data: dict of specified keys and values.
    """
    obs = raw_data['obs']
    acs = raw_data['acs']
    rews = raw_data['rews']
    news = raw_data['news']
    num_timesteps = rews.shape[0]

    data = {}
    data['obs'] = obs
    data['acs'] = acs

    Gts = rews.copy()
    episode_start_idx = 0
    for i in range(1, num_timesteps):
        if news[i] or (i == num_timesteps - 1):  # found one full episode
            if i < rews.shape[0] - 1:
                t = i - 1
            else:
                t = i
            Gt = 0
            while t >= episode_start_idx:
                Gt += rews[t]
                Gts[t] = Gt
                t -= 1

            episode_start_idx = i

    data['Gts'] = Gts

    return data
37  tianshou/data/advantage_estimation.py (new file)

@@ -0,0 +1,37 @@
import numpy as np


def full_return(raw_data):
    """
    naively compute full return
    :param raw_data: dict of specified keys and values.
    """
    observations = raw_data['observations']
    actions = raw_data['actions']
    rewards = raw_data['rewards']
    episode_start_flags = raw_data['episode_start_flags']
    num_timesteps = rewards.shape[0]

    data = {}
    data['observations'] = observations
    data['actions'] = actions

    returns = rewards.copy()
    episode_start_idx = 0
    for i in range(1, num_timesteps):
        if episode_start_flags[i] or (i == num_timesteps - 1):  # found the start of next episode or the end of all episodes
            if i < rewards.shape[0] - 1:
                t = i - 1
            else:
                t = i
            Gt = 0
            while t >= episode_start_idx:
                Gt += rewards[t]
                returns[t] = Gt
                t -= 1

            episode_start_idx = i

    data['returns'] = returns

    return data
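To see what `full_return` produces, here is a small worked example. The toy arrays are made up for illustration: two episodes of three steps each, marked by `episode_start_flags`.

```python
import numpy as np
import tianshou.data.advantage_estimation as advantage_estimation

raw_data = {
    'observations': np.zeros((6, 1)),                       # dummy observations
    'actions': np.zeros(6, dtype=np.int32),                 # dummy actions
    'rewards': np.array([1., 2., 3., 1., 1., 1.]),          # two episodes of length 3
    'episode_start_flags': np.array([True, False, False, True, False, False]),
}

data = advantage_estimation.full_return(raw_data)
print(data['returns'])  # [6. 5. 3. 3. 2. 1.]: undiscounted return-to-go within each episode
```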
128  tianshou/data/batch.py (new file)

@@ -0,0 +1,128 @@
import numpy as np
import gc


# TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner
class Batch(object):
    """
    class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
    """

    def __init__(self, env, pi, advantage_estimation_function):  # how to name the function?
        self._env = env
        self._pi = pi
        self._advantage_estimation_function = advantage_estimation_function
        self._is_first_collect = True

    def collect(self, num_timesteps=0, num_episodes=0, apply_function=True):  # specify how many data to collect here, or fix it in __init__()
        assert sum([num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"

        if num_timesteps > 0:  # YouQiaoben: finish this implementation, the following code are just from openai/baselines
            t = 0
            ac = self.env.action_space.sample()  # not used, just so we have the datatype
            new = True  # marks if we're on first timestep of an episode
            if self.is_first_collect:
                ob = self.env.reset()
                self.is_first_collect = False
            else:
                ob = self.raw_data['observations'][0]  # last observation!

            # Initialize history arrays
            observations = np.array([ob for _ in range(num_timesteps)])
            rewards = np.zeros(num_timesteps, 'float32')
            episode_start_flags = np.zeros(num_timesteps, 'int32')
            actions = np.array([ac for _ in range(num_timesteps)])

            for t in range(num_timesteps):
                pass

            while True:
                prevac = ac
                ac, vpred = pi.act(stochastic, ob)
                # Slight weirdness here because we need value function at time T
                # before returning segment [0, T-1] so we get the correct
                # terminal value
                i = t % horizon
                observations[i] = ob
                vpreds[i] = vpred
                episode_start_flags[i] = new
                actions[i] = ac
                prevacs[i] = prevac

                ob, rew, new, _ = env.step(ac)
                rewards[i] = rew

                cur_ep_ret += rew
                cur_ep_len += 1
                if new:
                    ep_rets.append(cur_ep_ret)
                    ep_lens.append(cur_ep_len)
                    cur_ep_ret = 0
                    cur_ep_len = 0
                    ob = env.reset()
                t += 1

        if num_episodes > 0:  # YouQiaoben: fix memory growth, both del and gc.collect() fail
            # initialize rawdata lists
            if not self._is_first_collect:
                del self.observations
                del self.actions
                del self.rewards
                del self.episode_start_flags

            observations = []
            actions = []
            rewards = []
            episode_start_flags = []

            t_count = 0

            for _ in range(num_episodes):
                ob = self._env.reset()
                observations.append(ob)
                episode_start_flags.append(True)

                while True:
                    ac = self._pi.act(ob)
                    actions.append(ac)

                    ob, reward, done, _ = self._env.step(ac)
                    rewards.append(reward)

                    t_count += 1
                    if t_count >= 200:  # force episode stop, just to test if memory still grows
                        break

                    if done:  # end of episode, discard s_T
                        break
                    else:
                        observations.append(ob)
                        episode_start_flags.append(False)

            self.observations = np.array(observations)
            self.actions = np.array(actions)
            self.rewards = np.array(rewards)
            self.episode_start_flags = np.array(episode_start_flags)

            del observations
            del actions
            del rewards
            del episode_start_flags

            self.raw_data = {'observations': self.observations, 'actions': self.actions, 'rewards': self.rewards, 'episode_start_flags': self.episode_start_flags}

            self._is_first_collect = False

            if apply_function:
                self.apply_advantage_estimation_function()

        gc.collect()

    def apply_advantage_estimation_function(self):
        self.data = self._advantage_estimation_function(self.raw_data)

    def next_batch(self, batch_size):  # YouQiaoben: referencing other iterate over batches
        rand_idx = np.random.choice(self.data['observations'].shape[0], batch_size)
        return {key: value[rand_idx] for key, value in self.data.items()}
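A minimal usage sketch of the new `Batch` class. It is illustrative only: the `RandomPolicy` stand-in and the CartPole environment are assumptions, standing in for the real policy object and the Atari setup used in examples/ppo_example.py.

```python
import gym
import tianshou.data.advantage_estimation as advantage_estimation
from tianshou.data.batch import Batch


class RandomPolicy(object):
    """Stand-in policy exposing the act() interface that Batch expects (illustrative only)."""

    def __init__(self, action_space):
        self._action_space = action_space

    def act(self, observation):
        return self._action_space.sample()


env = gym.make('CartPole-v0')
pi = RandomPolicy(env.action_space)

training_data = Batch(env, pi, advantage_estimation.full_return)
training_data.collect(num_episodes=2)      # on-policy rollout; full_return is applied to the raw data
minibatch = training_data.next_batch(64)   # random minibatch, as in examples/ppo_example.py
print(sorted(minibatch.keys()))            # ['actions', 'observations', 'returns']
```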