107 lines
3.4 KiB
Python
Raw Normal View History

import gc
import numpy as np
import logging
from .base import DataBufferBase
# Positional indices into a transition "frame" tuple: (state, action, reward, done).
# `add` reads frame[DONE] and `statistics` reads frame[REWARD] using these.
STATE = 0
ACTION = 1
REWARD = 2
DONE = 3
class BatchSet(DataBufferBase):
"""
class for batched dataset as used in on-policy algos
"""
def __init__(self, nstep=None):
    """Create an empty batch buffer.

    :param nstep: lookahead horizon. ``None`` (or any falsy value) falls
        back to 1, since RL has to look ahead at least one timestep.
    """
    self.nstep = nstep if nstep else 1
    self.data = [[]]           # raw frames, one inner list per episode
    self.index = [[]]          # valid data-point positions, per episode
    self.candidate_index = 0   # next frame in the open episode awaiting promotion
    self.size = 0              # number of valid data points (not frames)
    self.index_lengths = [0]   # per-episode counts of valid points, used for sampling
def add(self, frame):
    """Append one frame to the open episode and promote data points.

    A frame becomes a valid data point once more than ``nstep`` frames
    exist after it was appended, or — when the episode terminates — all
    remaining unpromoted (trailing) frames are made valid at once.

    :param frame: transition tuple indexed by STATE/ACTION/REWARD/DONE.
    """
    episode = self.data[-1]
    episode.append(frame)
    if frame[DONE]:
        # Episode over: every not-yet-promoted frame becomes a valid point.
        trailing = list(range(self.candidate_index, len(episode)))
        self.index[-1].extend(trailing)
        self.size += len(trailing)
        self.index_lengths[-1] += len(trailing)
        # Open fresh containers for the next episode.
        self.data.append([])
        self.index.append([])
        self.candidate_index = 0
        self.index_lengths.append(0)
    elif len(episode) > self.nstep:
        # Enough lookahead has accumulated: promote exactly one frame.
        self.index[-1].append(self.candidate_index)
        self.candidate_index += 1
        self.size += 1
        self.index_lengths[-1] += 1
def clear(self):
    """Drop all stored episodes and reset to the freshly-constructed state."""
    # Explicitly release the large containers and collect before
    # rebuilding, so the frames are reclaimed right away.
    del self.data, self.index, self.index_lengths
    gc.collect()
    self.data, self.index = [[]], [[]]
    self.candidate_index = 0
    self.size = 0
    self.index_lengths = [0]
def sample(self, batch_size):
    """Sample ``batch_size`` data-point positions, grouped by episode.

    An episode is drawn with probability proportional to its number of
    valid data points, then a position is drawn uniformly within it
    (sampling is with replacement).

    :param batch_size: number of data points to draw.
    :return: list with one list per stored episode; element ``i`` holds
        the sampled positions (indices into ``self.index[i]``) drawn
        from episode ``i``.
    :raises ValueError: if the buffer holds no valid data points yet.
    """
    # TODO: move unified properties and methods to base. but this depends on how to deal with nstep
    if self.size == 0:
        # Fix: without this guard the division below yields NaN/inf
        # probabilities and np.random.choice fails with a cryptic
        # "probabilities do not sum to 1" error.
        raise ValueError('Cannot sample from an empty BatchSet.')
    prob_episode = np.array(self.index_lengths) * 1. / self.size
    num_episodes = len(self.index)
    sampled_index = [[] for _ in range(num_episodes)]
    for _ in range(batch_size):
        # sample which episode
        sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))
        # sample which data point within the sampled episode; episodes with
        # zero valid points have probability 0 and are never selected here.
        sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
        sampled_index[sampled_episode_i].append(sampled_frame_i)
    return sampled_index
def statistics(self, discount_factor=0.99):
    """Compute and log mean discounted and undiscounted returns per episode.

    The open (unfinished) trailing episode is excluded, except when it is
    the only episode stored, in which case it is included and a warning
    is emitted.

    :param discount_factor: per-step discount gamma applied to rewards.
    """
    returns = []
    undiscounted_returns = []
    if len(self.data) == 1:
        # Only the still-running first episode exists; use it anyway.
        data = self.data
        logging.warning('The first episode in BatchSet is still not finished. '
                        'Logging its return anyway.')
    else:
        # Drop the last (still open) episode from the statistics.
        data = self.data[:-1]
    for episode in data:
        current_return = 0.
        current_undiscounted_return = 0.
        current_discount = 1.
        for frame in episode:
            # Accumulate gamma-discounted and raw reward sums.
            current_return += frame[REWARD] * current_discount
            current_undiscounted_return += frame[REWARD]
            current_discount *= discount_factor
        returns.append(current_return)
        undiscounted_returns.append(current_undiscounted_return)
    mean_return = np.mean(returns)
    mean_undiscounted_return = np.mean(undiscounted_returns)
    logging.info('Mean return: {}'.format(mean_return))
    logging.info('Mean undiscounted return: {}'.format(mean_undiscounted_return))