From 8c471e12d65ae7857d7c20bf25fcf5e840100f47 Mon Sep 17 00:00:00 2001 From: NM512 Date: Sat, 5 Aug 2023 21:11:34 +0900 Subject: [PATCH] erased unnecessary lines of code --- tools.py | 53 ----------------------------------------------------- 1 file changed, 53 deletions(-) diff --git a/tools.py b/tools.py index fb34c60..1cfed7d 100644 --- a/tools.py +++ b/tools.py @@ -237,59 +237,6 @@ def simulate( return (step - steps, episode - episodes, done, length, obs, agent_state, reward) -class CollectDataset: - def __init__( - self, env, mode, train_eps, eval_eps=dict(), callbacks=None, precision=32 - ): - self._env = env - self._callbacks = callbacks or () - self._precision = precision - self._episode = None - self._cache = dict(train=train_eps, eval=eval_eps)[mode] - self._temp_name = str(uuid.uuid4()) - - def __getattr__(self, name): - return getattr(self._env, name) - - def step(self, action): - obs, reward, done, info = self._env.step(action) - obs = {k: self._convert(v) for k, v in obs.items()} - transition = obs.copy() - if isinstance(action, dict): - transition.update(action) - else: - transition["action"] = action - transition["reward"] = reward - transition["discount"] = info.get("discount", np.array(1 - float(done))) - self._episode.append(transition) - self.add_to_cache(transition) - if done: - # detele transitions before whole episode is stored - del self._cache[self._temp_name] - self._temp_name = str(uuid.uuid4()) - for key, value in self._episode[1].items(): - if key not in self._episode[0]: - self._episode[0][key] = 0 * value - episode = {k: [t[k] for t in self._episode] for k in self._episode[0]} - episode = {k: self._convert(v) for k, v in episode.items()} - info["episode"] = episode - for callback in self._callbacks: - callback(episode) - return obs, reward, done, info - - def reset(self): - obs = self._env.reset() - transition = obs.copy() - # missing keys will be filled with a zeroed out version of the first - # transition, because we do not know what action information the agent will - # pass yet. - transition["reward"] = 0.0 - transition["discount"] = 1.0 - self._episode = [transition] - self.add_to_cache(transition) - return obs - - def add_to_cache(cache, id, transition): if id not in cache: cache[id] = dict()