erased unnecessary lines of code
This commit is contained in:
parent
eb14e2488b
commit
8c471e12d6
53
tools.py
53
tools.py
@ -237,59 +237,6 @@ def simulate(
|
|||||||
return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
|
return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
|
||||||
|
|
||||||
|
|
||||||
class CollectDataset:
|
|
||||||
def __init__(
|
|
||||||
self, env, mode, train_eps, eval_eps=dict(), callbacks=None, precision=32
|
|
||||||
):
|
|
||||||
self._env = env
|
|
||||||
self._callbacks = callbacks or ()
|
|
||||||
self._precision = precision
|
|
||||||
self._episode = None
|
|
||||||
self._cache = dict(train=train_eps, eval=eval_eps)[mode]
|
|
||||||
self._temp_name = str(uuid.uuid4())
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
obs, reward, done, info = self._env.step(action)
|
|
||||||
obs = {k: self._convert(v) for k, v in obs.items()}
|
|
||||||
transition = obs.copy()
|
|
||||||
if isinstance(action, dict):
|
|
||||||
transition.update(action)
|
|
||||||
else:
|
|
||||||
transition["action"] = action
|
|
||||||
transition["reward"] = reward
|
|
||||||
transition["discount"] = info.get("discount", np.array(1 - float(done)))
|
|
||||||
self._episode.append(transition)
|
|
||||||
self.add_to_cache(transition)
|
|
||||||
if done:
|
|
||||||
# detele transitions before whole episode is stored
|
|
||||||
del self._cache[self._temp_name]
|
|
||||||
self._temp_name = str(uuid.uuid4())
|
|
||||||
for key, value in self._episode[1].items():
|
|
||||||
if key not in self._episode[0]:
|
|
||||||
self._episode[0][key] = 0 * value
|
|
||||||
episode = {k: [t[k] for t in self._episode] for k in self._episode[0]}
|
|
||||||
episode = {k: self._convert(v) for k, v in episode.items()}
|
|
||||||
info["episode"] = episode
|
|
||||||
for callback in self._callbacks:
|
|
||||||
callback(episode)
|
|
||||||
return obs, reward, done, info
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
obs = self._env.reset()
|
|
||||||
transition = obs.copy()
|
|
||||||
# missing keys will be filled with a zeroed out version of the first
|
|
||||||
# transition, because we do not know what action information the agent will
|
|
||||||
# pass yet.
|
|
||||||
transition["reward"] = 0.0
|
|
||||||
transition["discount"] = 1.0
|
|
||||||
self._episode = [transition]
|
|
||||||
self.add_to_cache(transition)
|
|
||||||
return obs
|
|
||||||
|
|
||||||
|
|
||||||
def add_to_cache(cache, id, transition):
|
def add_to_cache(cache, id, transition):
|
||||||
if id not in cache:
|
if id not in cache:
|
||||||
cache[id] = dict()
|
cache[id] = dict()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user