sampling from the replay buffer across episodes

This commit is contained in:
NM512 2023-04-29 07:43:02 +09:00
parent 12cccd8475
commit 1328ff1088
3 changed files with 29 additions and 16 deletions

View File

@ -12,7 +12,7 @@ defaults:
log_every: 1e4
reset_every: 0
device: 'cuda:0'
compile: False
compile: True
precision: 16
debug: False
expl_gifs: False
@ -78,7 +78,6 @@ defaults:
value_grad_clip: 100
actor_grad_clip: 100
dataset_size: 1000000
oversample_ends: True
slow_value_target: True
slow_target_update: 1
slow_target_fraction: 0.02

View File

@ -174,9 +174,7 @@ def count_steps(folder):
def make_dataset(episodes, config):
generator = tools.sample_episodes(
episodes, config.batch_length, config.oversample_ends
)
generator = tools.sample_episodes(episodes, config.batch_length)
dataset = tools.from_generator(generator, config.batch_size)
return dataset

View File

@ -199,22 +199,38 @@ def from_generator(generator, batch_size):
yield data
def sample_episodes(episodes, length=None, balance=False, seed=0):
def sample_episodes(episodes, length, seed=0):
random = np.random.RandomState(seed)
while True:
episode = random.choice(list(episodes.values()))
if length:
size = 0
ret = None
p = np.array(
[len(next(iter(episode.values()))) for episode in episodes.values()]
)
p = p / np.sum(p)
while size < length:
episode = random.choice(list(episodes.values()), p=p)
total = len(next(iter(episode.values())))
available = total - length
if available < 1:
# print(f"Skipped short episode of length {available}.")
# make sure at least one transition included
if total < 2:
continue
if balance:
index = min(random.randint(0, total), available)
if not ret:
index = int(random.randint(0, total - 1))
ret = {
k: v[index : min(index + length, total)] for k, v in episode.items()
}
else:
index = int(random.randint(0, available + 1))
episode = {k: v[index : index + length] for k, v in episode.items()}
yield episode
# 'is_first' comes after 'is_last'
index = 0
possible = length - size
ret = {
k: np.append(
ret[k], v[index : min(index + possible, total)], axis=0
)
for k, v in episode.items()
}
size = len(next(iter(ret.values())))
yield ret
def load_episodes(directory, limit=None, reverse=True):