modifications for minecraft
This commit is contained in:
parent
8c471e12d6
commit
8571cf656a
@ -177,6 +177,9 @@ minecraft:
|
|||||||
task: minecraft_diamond
|
task: minecraft_diamond
|
||||||
step: 1e8
|
step: 1e8
|
||||||
envs: 16
|
envs: 16
|
||||||
|
# no eval
|
||||||
|
eval_episode_num: 0
|
||||||
|
eval_every: 1e8
|
||||||
action_repeat: 1
|
action_repeat: 1
|
||||||
train_ratio: 16
|
train_ratio: 16
|
||||||
video_pred_log: true
|
video_pred_log: true
|
||||||
|
32
dreamer.py
32
dreamer.py
@ -230,7 +230,8 @@ def make_env(config, mode):
|
|||||||
env = wrappers.TimeLimit(env, config.time_limit)
|
env = wrappers.TimeLimit(env, config.time_limit)
|
||||||
env = wrappers.SelectAction(env, key="action")
|
env = wrappers.SelectAction(env, key="action")
|
||||||
env = wrappers.UUID(env)
|
env = wrappers.UUID(env)
|
||||||
env = wrappers.RewardObs(env)
|
if suite == "minecraft":
|
||||||
|
env = wrappers.RewardObs(env)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
@ -326,20 +327,21 @@ def main(config):
|
|||||||
# make sure eval will be executed once after config.steps
|
# make sure eval will be executed once after config.steps
|
||||||
while agent._step < config.steps + config.eval_every:
|
while agent._step < config.steps + config.eval_every:
|
||||||
logger.write()
|
logger.write()
|
||||||
print("Start evaluation.")
|
if config.eval_episode_num > 0:
|
||||||
eval_policy = functools.partial(agent, training=False)
|
print("Start evaluation.")
|
||||||
tools.simulate(
|
eval_policy = functools.partial(agent, training=False)
|
||||||
eval_policy,
|
tools.simulate(
|
||||||
eval_envs,
|
eval_policy,
|
||||||
eval_eps,
|
eval_envs,
|
||||||
config.evaldir,
|
eval_eps,
|
||||||
logger,
|
config.evaldir,
|
||||||
is_eval=True,
|
logger,
|
||||||
episodes=config.eval_episode_num,
|
is_eval=True,
|
||||||
)
|
episodes=config.eval_episode_num,
|
||||||
if config.video_pred_log:
|
)
|
||||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
if config.video_pred_log:
|
||||||
logger.video("eval_openl", to_np(video_pred))
|
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||||
|
logger.video("eval_openl", to_np(video_pred))
|
||||||
print("Start training.")
|
print("Start training.")
|
||||||
state = tools.simulate(
|
state = tools.simulate(
|
||||||
agent,
|
agent,
|
||||||
|
@ -415,7 +415,7 @@ class MultiDecoder(nn.Module):
|
|||||||
vector_dist,
|
vector_dist,
|
||||||
):
|
):
|
||||||
super(MultiDecoder, self).__init__()
|
super(MultiDecoder, self).__init__()
|
||||||
excluded = ("is_first", "is_last", "is_terminal", "reward")
|
excluded = ("is_first", "is_last", "is_terminal")
|
||||||
shapes = {k: v for k, v in shapes.items() if k not in excluded}
|
shapes = {k: v for k, v in shapes.items() if k not in excluded}
|
||||||
self.cnn_shapes = {
|
self.cnn_shapes = {
|
||||||
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
|
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user