put running episode into replay buffer

This commit is contained in:
NM512 2023-04-24 06:25:17 +09:00
parent 6f0e6c6963
commit 432a359bcf
3 changed files with 27 additions and 3 deletions

View File

@ -225,7 +225,7 @@ def make_env(config, logger, mode, train_eps, eval_eps):
eval_eps,
)
]
env = wrappers.CollectDataset(env, callbacks)
env = wrappers.CollectDataset(env, mode, train_eps, callbacks=callbacks)
env = wrappers.RewardObs(env)
return env

View File

@ -1,13 +1,18 @@
import gym
import numpy as np
import uuid
class CollectDataset:
def __init__(self, env, callbacks=None, precision=32):
def __init__(
    self, env, mode, train_eps, eval_eps=None, callbacks=None, precision=32
):
    """Wrap `env` and select the episode dict that receives running episodes.

    Args:
        env: Environment to wrap; attributes missing on the wrapper are
            looked up on it (see `__getattr__`).
        mode: "train" or "eval"; picks `train_eps` or `eval_eps` as the
            cache that `add_to_cache` writes into.
        train_eps: Episode dict used as the cache when `mode` is "train".
        eval_eps: Episode dict used as the cache when `mode` is "eval".
            When omitted, a fresh dict private to this instance is used.
        callbacks: Optional callbacks; stored as an empty tuple when falsy.
            Presumably invoked with finished episodes -- confirm in `step`.
        precision: Stored on the instance; presumably the bit width used
            by `_convert` -- confirm there.

    Raises:
        KeyError: If `mode` is neither "train" nor "eval".
    """
    # A `dict()` default would be created once at definition time and then
    # shared (and mutated through the cache) by every eval-mode wrapper
    # constructed without an explicit `eval_eps`.
    if eval_eps is None:
        eval_eps = {}
    self._env = env
    self._callbacks = callbacks or ()
    self._precision = precision
    self._episode = None
    self._cache = dict(train=train_eps, eval=eval_eps)[mode]
    # Temporary key under which the in-progress episode lives in the cache.
    self._temp_name = str(uuid.uuid4())
def __getattr__(self, name):
    """Look up attributes that are missing on the wrapper on the wrapped env.

    Raises:
        AttributeError: If `_env` itself is requested while unset, or if
            the wrapped env does not have `name`.
    """
    # This hook only runs after normal lookup has failed, so being asked for
    # "_env" means the instance has no env yet (e.g. it was created without
    # running __init__). Reading self._env below would re-enter this method
    # and recurse until RecursionError.
    if name == "_env":
        raise AttributeError(name)
    return getattr(self._env, name)
@ -23,7 +28,11 @@ class CollectDataset:
transition["reward"] = reward
transition["discount"] = info.get("discount", np.array(1 - float(done)))
self._episode.append(transition)
self.add_to_cache(transition)
if done:
# delete transitions before whole episode is stored
del self._cache[self._temp_name]
self._temp_name = str(uuid.uuid4())
for key, value in self._episode[1].items():
if key not in self._episode[0]:
self._episode[0][key] = 0 * value
@ -43,8 +52,23 @@ class CollectDataset:
transition["reward"] = 0.0
transition["discount"] = 1.0
self._episode = [transition]
self.add_to_cache(transition)
return obs
def add_to_cache(self, transition):
    """Append one transition to the running episode held in the cache.

    The running episode is stored in `self._cache` under `self._temp_name`
    as a dict mapping each transition key to a list of converted values,
    one entry per transition added so far.

    Args:
        transition: Mapping of key -> value for a single step. Values of
            keys not seen before must support `0 * value`.
    """
    if self._temp_name not in self._cache:
        # First transition of the episode: start one list per key.
        self._cache[self._temp_name] = {
            key: [self._convert(val)] for key, val in transition.items()
        }
        return

    episode = self._cache[self._temp_name]
    # Number of transitions already stored. The entry exists, so at least
    # one transition was added before, even if it carried no keys.
    stored = max((len(seq) for seq in episode.values()), default=1)
    for key, val in transition.items():
        if key not in episode:
            # Fill missing data (e.g. action) with zeros for every earlier
            # step so all per-key lists stay the same length.
            episode[key] = [self._convert(0 * val) for _ in range(stored)]
        episode[key].append(self._convert(val))
def _convert(self, value):
value = np.array(value)
if np.issubdtype(value.dtype, np.floating):

View File

@ -207,7 +207,7 @@ def sample_episodes(episodes, length=None, balance=False, seed=0):
total = len(next(iter(episode.values())))
available = total - length
if available < 1:
print(f"Skipped short episode of length {available}.")
# print(f"Skipped short episode of length {available}.")
continue
if balance:
index = min(random.randint(0, total), available)