step-based counting
This commit is contained in:
parent
036e9a8028
commit
0ae6d2d1e0
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@
|
|||||||
*.sh
|
*.sh
|
||||||
logdir*
|
logdir*
|
||||||
vis_*
|
vis_*
|
||||||
|
result*
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
@ -337,6 +337,7 @@ def main(config):
|
|||||||
acts = train_envs[0].action_space
|
acts = train_envs[0].action_space
|
||||||
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
|
config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]
|
||||||
|
|
||||||
|
state = None
|
||||||
if not config.offline_traindir:
|
if not config.offline_traindir:
|
||||||
prefill = max(0, config.prefill - count_steps(config.traindir))
|
prefill = max(0, config.prefill - count_steps(config.traindir))
|
||||||
print(f"Prefill dataset ({prefill} steps).")
|
print(f"Prefill dataset ({prefill} steps).")
|
||||||
@ -358,7 +359,7 @@ def main(config):
|
|||||||
logprob = random_actor.log_prob(action)
|
logprob = random_actor.log_prob(action)
|
||||||
return {"action": action, "logprob": logprob}, None
|
return {"action": action, "logprob": logprob}, None
|
||||||
|
|
||||||
tools.simulate(random_agent, train_envs, prefill)
|
state = tools.simulate(random_agent, train_envs, prefill)
|
||||||
logger.step = config.action_repeat * count_steps(config.traindir)
|
logger.step = config.action_repeat * count_steps(config.traindir)
|
||||||
|
|
||||||
print("Simulate agent.")
|
print("Simulate agent.")
|
||||||
@ -376,7 +377,6 @@ def main(config):
|
|||||||
agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
|
agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
|
||||||
agent._should_pretrain._once = False
|
agent._should_pretrain._once = False
|
||||||
|
|
||||||
state = None
|
|
||||||
while agent._step < config.steps:
|
while agent._step < config.steps:
|
||||||
logger.write()
|
logger.write()
|
||||||
print("Start evaluation.")
|
print("Start evaluation.")
|
||||||
|
2
tools.py
2
tools.py
@ -161,7 +161,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None):
|
|||||||
done = np.stack(done)
|
done = np.stack(done)
|
||||||
episode += int(done.sum())
|
episode += int(done.sum())
|
||||||
length += 1
|
length += 1
|
||||||
step += (done * length).sum()
|
step += len(envs)
|
||||||
length *= 1 - done
|
length *= 1 - done
|
||||||
|
|
||||||
return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
|
return (step - steps, episode - episodes, done, length, obs, agent_state, reward)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user