"""Batched data buffer (:class:`BatchSet`) for on-policy RL algorithms."""
import gc
|
|
import numpy as np
|
|
import logging
|
|
|
|
from .base import DataBufferBase
|
|
|
|
__all__ = [
    'BatchSet'
]


# Indices of the fields within one frame tuple, as appended by
# ``BatchSet.add``: (observation, action, reward, done_flag).
STATE = 0
ACTION = 1
REWARD = 2
DONE = 3
|
|
|
|
|
|
class BatchSet(DataBufferBase):
    """
    Class for batched dataset as used in on-policy algorithms, where a batch of data is first collected
    with the current policy, several optimization steps are then conducted on this batch of data and the
    data are then discarded and collected again.

    :param nstep: An int defaulting to 1. The number of timesteps to lookahead for temporal difference
        computation. Only continuous data pieces longer than this number or already terminated ones are
        considered valid data points.
    """
    def __init__(self, nstep=1):
        self.nstep = nstep  # RL has to look ahead at least one timestep

        # data[i] holds the frames of episode i; index[i] holds the positions
        # within data[i] that have been validated as data points.
        self.data = [[]]
        self.index = [[]]
        # position of the first frame of the current episode that has not yet
        # been marked as a valid data point
        self.candidate_index = 0

        self.size = 0  # number of valid data points (not frames)

        self.index_lengths = [0]  # per-episode valid-point counts, for sampling

    def add(self, frame):
        """
        Adds one frame of data to the buffer.

        :param frame: A tuple of (observation, action, reward, done_flag).
        """
        self.data[-1].append(frame)

        # a frame only becomes a valid data point once nstep lookahead frames exist
        has_enough_frames = len(self.data[-1]) > self.nstep
        if frame[DONE]:  # episode terminates, all trailing frames become valid data points
            trailing_index = list(range(self.candidate_index, len(self.data[-1])))
            self.index[-1] += trailing_index
            self.size += len(trailing_index)
            self.index_lengths[-1] += len(trailing_index)

            # prepare for the next episode
            self.data.append([])
            self.index.append([])
            self.candidate_index = 0

            self.index_lengths.append(0)

        elif has_enough_frames:  # add one valid data point
            self.index[-1].append(self.candidate_index)
            self.candidate_index += 1
            self.size += 1
            self.index_lengths[-1] += 1

    def clear(self):
        """
        Empties the data buffer and prepares to collect a new batch of data.
        """
        # explicitly drop the old containers and force a collection so a
        # potentially large batch is reclaimed immediately
        del self.data
        del self.index
        del self.index_lengths

        gc.collect()

        self.data = [[]]
        self.index = [[]]
        self.candidate_index = 0
        self.size = 0
        self.index_lengths = [0]

    def sample(self, batch_size):
        """
        Performs uniform random sampling on ``self.index``. For simplicity, we do random sampling with
        replacement for now with time O(``batch_size``). Fastest sampling without replacement seems to
        have to be of time O(``batch_size`` * log(num_episodes)).

        :param batch_size: An int. The size of the minibatch.

        :return: A list of list of the sampled indexes. Episodes without sampled data points
            correspond to empty sub-lists.
            NOTE(review): each sampled value is a position into ``self.index[episode]``
            (drawn from ``range(index_lengths[episode])``), not a raw frame number —
            confirm against the consumer of this method.

        :raises ValueError: If the buffer contains no valid data points yet.
        """
        # TODO: move unified properties and methods to base. but this may depend on how to deal with nstep

        if self.size == 0:
            # without this guard the division below produces NaN probabilities and
            # np.random.choice fails with a cryptic "probabilities contain NaN" error
            raise ValueError('Cannot sample from an empty BatchSet.')

        # ``* 1.`` forces float division even on legacy integer-division semantics
        prob_episode = np.array(self.index_lengths) * 1. / self.size
        num_episodes = len(self.index)
        sampled_index = [[] for _ in range(num_episodes)]

        for _ in range(batch_size):
            # sample which episode, weighted by its number of valid data points
            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))

            # sample which data point within the sampled episode
            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
            sampled_index[sampled_episode_i].append(sampled_frame_i)

        return sampled_index

    def statistics(self, discount_factor=0.99):
        """
        Computes and prints out the statistics (e.g., discounted returns, undiscounted returns) in the
        batch set. This is useful when policies are optimized by on-policy algorithms, so the current
        data in the batch set directly reflect the performance of the current policy.

        :param discount_factor: Optional. A float in range :math:`[0, 1]` defaulting to 0.99. The
            discount factor to compute discounted returns.
        """
        returns = []
        undiscounted_returns = []

        if len(self.data) == 1:
            data = self.data
            logging.warning('The first episode in BatchSet is still not finished. '
                            'Logging its return anyway.')
        else:
            # drop the trailing episode: it is either unfinished or the empty
            # list appended right after a terminal frame
            data = self.data[:-1]

        for episode in data:
            current_return = 0.
            current_undiscounted_return = 0.
            current_discount = 1.
            for frame in episode:
                current_return += frame[REWARD] * current_discount
                current_undiscounted_return += frame[REWARD]
                current_discount *= discount_factor
            returns.append(current_return)
            undiscounted_returns.append(current_undiscounted_return)

        mean_return = np.mean(returns)
        mean_undiscounted_return = np.mean(undiscounted_returns)

        print('Mean return: {}'.format(mean_return))
        print('Mean undiscounted return: {}'.format(mean_undiscounted_return))
|