sampling from the replay buffer across episodes
This commit is contained in:
parent
12cccd8475
commit
1328ff1088
@ -12,7 +12,7 @@ defaults:
|
|||||||
log_every: 1e4
|
log_every: 1e4
|
||||||
reset_every: 0
|
reset_every: 0
|
||||||
device: 'cuda:0'
|
device: 'cuda:0'
|
||||||
compile: False
|
compile: True
|
||||||
precision: 16
|
precision: 16
|
||||||
debug: False
|
debug: False
|
||||||
expl_gifs: False
|
expl_gifs: False
|
||||||
@ -78,7 +78,6 @@ defaults:
|
|||||||
value_grad_clip: 100
|
value_grad_clip: 100
|
||||||
actor_grad_clip: 100
|
actor_grad_clip: 100
|
||||||
dataset_size: 1000000
|
dataset_size: 1000000
|
||||||
oversample_ends: True
|
|
||||||
slow_value_target: True
|
slow_value_target: True
|
||||||
slow_target_update: 1
|
slow_target_update: 1
|
||||||
slow_target_fraction: 0.02
|
slow_target_fraction: 0.02
|
||||||
|
@ -174,9 +174,7 @@ def count_steps(folder):
|
|||||||
|
|
||||||
|
|
||||||
def make_dataset(episodes, config):
|
def make_dataset(episodes, config):
|
||||||
generator = tools.sample_episodes(
|
generator = tools.sample_episodes(episodes, config.batch_length)
|
||||||
episodes, config.batch_length, config.oversample_ends
|
|
||||||
)
|
|
||||||
dataset = tools.from_generator(generator, config.batch_size)
|
dataset = tools.from_generator(generator, config.batch_size)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
38
tools.py
38
tools.py
@ -199,22 +199,38 @@ def from_generator(generator, batch_size):
|
|||||||
yield data
|
yield data
|
||||||
|
|
||||||
|
|
||||||
def sample_episodes(episodes, length=None, balance=False, seed=0):
|
def sample_episodes(episodes, length, seed=0):
|
||||||
random = np.random.RandomState(seed)
|
random = np.random.RandomState(seed)
|
||||||
while True:
|
while True:
|
||||||
episode = random.choice(list(episodes.values()))
|
size = 0
|
||||||
if length:
|
ret = None
|
||||||
|
p = np.array(
|
||||||
|
[len(next(iter(episode.values()))) for episode in episodes.values()]
|
||||||
|
)
|
||||||
|
p = p / np.sum(p)
|
||||||
|
while size < length:
|
||||||
|
episode = random.choice(list(episodes.values()), p=p)
|
||||||
total = len(next(iter(episode.values())))
|
total = len(next(iter(episode.values())))
|
||||||
available = total - length
|
# make sure at least one transition is included
|
||||||
if available < 1:
|
if total < 2:
|
||||||
# print(f"Skipped short episode of length {available}.")
|
|
||||||
continue
|
continue
|
||||||
if balance:
|
if not ret:
|
||||||
index = min(random.randint(0, total), available)
|
index = int(random.randint(0, total - 1))
|
||||||
|
ret = {
|
||||||
|
k: v[index : min(index + length, total)] for k, v in episode.items()
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
index = int(random.randint(0, available + 1))
|
# start the appended episode at index 0 so its 'is_first' comes right after the previous chunk's 'is_last'
|
||||||
episode = {k: v[index : index + length] for k, v in episode.items()}
|
index = 0
|
||||||
yield episode
|
possible = length - size
|
||||||
|
ret = {
|
||||||
|
k: np.append(
|
||||||
|
ret[k], v[index : min(index + possible, total)], axis=0
|
||||||
|
)
|
||||||
|
for k, v in episode.items()
|
||||||
|
}
|
||||||
|
size = len(next(iter(ret.values())))
|
||||||
|
yield ret
|
||||||
|
|
||||||
|
|
||||||
def load_episodes(directory, limit=None, reverse=True):
|
def load_episodes(directory, limit=None, reverse=True):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user