sampling from the replay buffer across episodes
This commit is contained in:
parent
12cccd8475
commit
1328ff1088
@ -12,7 +12,7 @@ defaults:
|
||||
log_every: 1e4
|
||||
reset_every: 0
|
||||
device: 'cuda:0'
|
||||
compile: False
|
||||
compile: True
|
||||
precision: 16
|
||||
debug: False
|
||||
expl_gifs: False
|
||||
@ -78,7 +78,6 @@ defaults:
|
||||
value_grad_clip: 100
|
||||
actor_grad_clip: 100
|
||||
dataset_size: 1000000
|
||||
oversample_ends: True
|
||||
slow_value_target: True
|
||||
slow_target_update: 1
|
||||
slow_target_fraction: 0.02
|
||||
|
@ -174,9 +174,7 @@ def count_steps(folder):
|
||||
|
||||
|
||||
def make_dataset(episodes, config):
|
||||
generator = tools.sample_episodes(
|
||||
episodes, config.batch_length, config.oversample_ends
|
||||
)
|
||||
generator = tools.sample_episodes(episodes, config.batch_length)
|
||||
dataset = tools.from_generator(generator, config.batch_size)
|
||||
return dataset
|
||||
|
||||
|
38
tools.py
38
tools.py
@ -199,22 +199,38 @@ def from_generator(generator, batch_size):
|
||||
yield data
|
||||
|
||||
|
||||
def sample_episodes(episodes, length=None, balance=False, seed=0):
|
||||
def sample_episodes(episodes, length, seed=0):
|
||||
random = np.random.RandomState(seed)
|
||||
while True:
|
||||
episode = random.choice(list(episodes.values()))
|
||||
if length:
|
||||
size = 0
|
||||
ret = None
|
||||
p = np.array(
|
||||
[len(next(iter(episode.values()))) for episode in episodes.values()]
|
||||
)
|
||||
p = p / np.sum(p)
|
||||
while size < length:
|
||||
episode = random.choice(list(episodes.values()), p=p)
|
||||
total = len(next(iter(episode.values())))
|
||||
available = total - length
|
||||
if available < 1:
|
||||
# print(f"Skipped short episode of length {available}.")
|
||||
# make sure at least one transition included
|
||||
if total < 2:
|
||||
continue
|
||||
if balance:
|
||||
index = min(random.randint(0, total), available)
|
||||
if not ret:
|
||||
index = int(random.randint(0, total - 1))
|
||||
ret = {
|
||||
k: v[index : min(index + length, total)] for k, v in episode.items()
|
||||
}
|
||||
else:
|
||||
index = int(random.randint(0, available + 1))
|
||||
episode = {k: v[index : index + length] for k, v in episode.items()}
|
||||
yield episode
|
||||
# 'is_first' comes after 'is_last'
|
||||
index = 0
|
||||
possible = length - size
|
||||
ret = {
|
||||
k: np.append(
|
||||
ret[k], v[index : min(index + possible, total)], axis=0
|
||||
)
|
||||
for k, v in episode.items()
|
||||
}
|
||||
size = len(next(iter(ret.values())))
|
||||
yield ret
|
||||
|
||||
|
||||
def load_episodes(directory, limit=None, reverse=True):
|
||||
|
Loading…
x
Reference in New Issue
Block a user