modifications for minecraft

This commit is contained in:
NM512 2023-08-05 21:13:57 +09:00
parent 8c471e12d6
commit 8571cf656a
3 changed files with 21 additions and 16 deletions

View File

@ -177,6 +177,9 @@ minecraft:
task: minecraft_diamond task: minecraft_diamond
step: 1e8 step: 1e8
envs: 16 envs: 16
# no eval
eval_episode_num: 0
eval_every: 1e8
action_repeat: 1 action_repeat: 1
train_ratio: 16 train_ratio: 16
video_pred_log: true video_pred_log: true

View File

@ -230,7 +230,8 @@ def make_env(config, mode):
env = wrappers.TimeLimit(env, config.time_limit) env = wrappers.TimeLimit(env, config.time_limit)
env = wrappers.SelectAction(env, key="action") env = wrappers.SelectAction(env, key="action")
env = wrappers.UUID(env) env = wrappers.UUID(env)
env = wrappers.RewardObs(env) if suite == "minecraft":
env = wrappers.RewardObs(env)
return env return env
@ -326,20 +327,21 @@ def main(config):
# make sure eval will be executed once after config.steps # make sure eval will be executed once after config.steps
while agent._step < config.steps + config.eval_every: while agent._step < config.steps + config.eval_every:
logger.write() logger.write()
print("Start evaluation.") if config.eval_episode_num > 0:
eval_policy = functools.partial(agent, training=False) print("Start evaluation.")
tools.simulate( eval_policy = functools.partial(agent, training=False)
eval_policy, tools.simulate(
eval_envs, eval_policy,
eval_eps, eval_envs,
config.evaldir, eval_eps,
logger, config.evaldir,
is_eval=True, logger,
episodes=config.eval_episode_num, is_eval=True,
) episodes=config.eval_episode_num,
if config.video_pred_log: )
video_pred = agent._wm.video_pred(next(eval_dataset)) if config.video_pred_log:
logger.video("eval_openl", to_np(video_pred)) video_pred = agent._wm.video_pred(next(eval_dataset))
logger.video("eval_openl", to_np(video_pred))
print("Start training.") print("Start training.")
state = tools.simulate( state = tools.simulate(
agent, agent,

View File

@ -415,7 +415,7 @@ class MultiDecoder(nn.Module):
vector_dist, vector_dist,
): ):
super(MultiDecoder, self).__init__() super(MultiDecoder, self).__init__()
excluded = ("is_first", "is_last", "is_terminal", "reward") excluded = ("is_first", "is_last", "is_terminal")
shapes = {k: v for k, v in shapes.items() if k not in excluded} shapes = {k: v for k, v in shapes.items() if k not in excluded}
self.cnn_shapes = { self.cnn_shapes = {
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k) k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)