From 0ae6d2d1e0cc9ae7b4515e30eb6cf77e646fa747 Mon Sep 17 00:00:00 2001 From: NM512 Date: Sun, 2 Jul 2023 11:51:11 +0900 Subject: [PATCH] step-based counting --- .gitignore | 1 + dreamer.py | 4 ++-- tools.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index f40ff34..63e2c40 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.sh logdir* vis_* +result* # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/dreamer.py b/dreamer.py index f593669..fba6bce 100644 --- a/dreamer.py +++ b/dreamer.py @@ -337,6 +337,7 @@ def main(config): acts = train_envs[0].action_space config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0] + state = None if not config.offline_traindir: prefill = max(0, config.prefill - count_steps(config.traindir)) print(f"Prefill dataset ({prefill} steps).") @@ -358,7 +359,7 @@ def main(config): logprob = random_actor.log_prob(action) return {"action": action, "logprob": logprob}, None - tools.simulate(random_agent, train_envs, prefill) + state = tools.simulate(random_agent, train_envs, prefill) logger.step = config.action_repeat * count_steps(config.traindir) print("Simulate agent.") @@ -376,7 +377,6 @@ def main(config): agent.load_state_dict(torch.load(logdir / "latest_model.pt")) agent._should_pretrain._once = False - state = None while agent._step < config.steps: logger.write() print("Start evaluation.") diff --git a/tools.py b/tools.py index 14aa42d..59da93a 100644 --- a/tools.py +++ b/tools.py @@ -161,7 +161,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None): done = np.stack(done) episode += int(done.sum()) length += 1 - step += (done * length).sum() + step += len(envs) length *= 1 - done return (step - steps, episode - episodes, done, length, obs, agent_state, reward)