From f07d8439535a1a63d99bfa6619e1c493bfe2812e Mon Sep 17 00:00:00 2001
From: NM512
Date: Sat, 22 Jul 2023 20:53:43 +0900
Subject: [PATCH] erased unnecessary reward input

---
 dreamer.py | 4 ++--
 tools.py   | 3 +--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/dreamer.py b/dreamer.py
index c7fa85e..0a98d3f 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -66,7 +66,7 @@ class Dreamer(nn.Module):
             plan2explore=lambda: expl.Plan2Explore(config, self._wm, reward),
         )[config.expl_behavior]().to(self._config.device)
 
-    def __call__(self, obs, reset, state=None, reward=None, training=True):
+    def __call__(self, obs, reset, state=None, training=True):
         step = self._step
         if self._should_reset(step):
             state = None
@@ -295,7 +295,7 @@ def main(config):
                 1,
             )
 
-        def random_agent(o, d, s, r):
+        def random_agent(o, d, s):
             action = random_actor.sample()
             logprob = random_actor.log_prob(action)
             return {"action": action, "logprob": logprob}, None
diff --git a/tools.py b/tools.py
index c210791..45048f5 100644
--- a/tools.py
+++ b/tools.py
@@ -148,10 +148,9 @@ def simulate(agent, envs, cache, directory, logger, is_eval=False, limit=None, s
                 add_to_cache(cache, envs[i].id, t)
         for index, result in zip(indices, results):
             obs[index] = result
-        reward = [reward[i] * (1 - done[i]) for i in range(len(envs))]
         # Step agents.
         obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]}
-        action, agent_state = agent(obs, done, agent_state, reward)
+        action, agent_state = agent(obs, done, agent_state)
         if isinstance(action, dict):
             action = [
                 {k: np.array(action[k][i].detach().cpu()) for k in action}
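
Note: after this patch, `tools.simulate` no longer threads the per-step reward
into the agent (the reward was only being masked and forwarded, never consumed),
so anything passed as `agent` only needs the three-argument signature
`(obs, done, state)` and must return an action dict plus the new state, as the
updated `random_agent` does. A minimal sketch of a conforming stub under that
contract; `stub_agent`, `num_envs`, and `num_actions` are illustrative names and
values, not taken from the patch:

    import torch
    from torch.distributions import OneHotCategorical

    num_envs, num_actions = 4, 6  # illustrative values, not from the patch

    # Uniform random policy over discrete actions, one sample per env.
    actor = OneHotCategorical(torch.ones(num_envs, num_actions) / num_actions)

    def stub_agent(obs, done, state):
        # obs: dict of per-env observation arrays stacked along axis 0
        # done: per-env episode-termination flags (unused by this stub)
        # state: opaque agent state threaded through simulate() between steps
        action = actor.sample()           # (num_envs, num_actions) one-hot
        logprob = actor.log_prob(action)  # (num_envs,)
        return {"action": action, "logprob": logprob}, state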