diff --git a/.gitignore b/.gitignore index f40ff34..63e2c40 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.sh logdir* vis_* +result* # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/dreamer.py b/dreamer.py index f593669..fba6bce 100644 --- a/dreamer.py +++ b/dreamer.py @@ -337,6 +337,7 @@ def main(config): acts = train_envs[0].action_space config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0] + state = None if not config.offline_traindir: prefill = max(0, config.prefill - count_steps(config.traindir)) print(f"Prefill dataset ({prefill} steps).") @@ -358,7 +359,7 @@ def main(config): logprob = random_actor.log_prob(action) return {"action": action, "logprob": logprob}, None - tools.simulate(random_agent, train_envs, prefill) + state = tools.simulate(random_agent, train_envs, prefill) logger.step = config.action_repeat * count_steps(config.traindir) print("Simulate agent.") @@ -376,7 +377,6 @@ def main(config): agent.load_state_dict(torch.load(logdir / "latest_model.pt")) agent._should_pretrain._once = False - state = None while agent._step < config.steps: logger.write() print("Start evaluation.") diff --git a/tools.py b/tools.py index 14aa42d..59da93a 100644 --- a/tools.py +++ b/tools.py @@ -161,7 +161,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None): done = np.stack(done) episode += int(done.sum()) length += 1 - step += (done * length).sum() + step += len(envs) length *= 1 - done return (step - steps, episode - episodes, done, length, obs, agent_state, reward)