added benchmark task Crafter

This commit is contained in:
NM512 2023-06-18 00:02:22 +09:00
parent 9c58ab62c0
commit 5dce8cf13b
6 changed files with 111 additions and 5 deletions

View File

@ -19,6 +19,10 @@ Train the agent on Alien in Atari 100K:
```
python3 dreamer.py --configs atari100k --task atari_alien --logdir ./logdir/atari_alien
```
Train the agent on Crafter:
```
python3 dreamer.py --configs crafter --logdir ./logdir/crafter
```
Monitor results:
```
tensorboard --logdir ./logdir

View File

@ -138,6 +138,25 @@ dmc_vision:
encoder: {mlp_keys: '$^', cnn_keys: 'image'}
decoder: {mlp_keys: '$^', cnn_keys: 'image'}
crafter:
task: crafter_reward
step: 1e6
action_repeat: 1
envs: 1
train_ratio: 512
video_pred_log: true
dyn_hidden: 1024
dyn_deter: 4096
units: 1024
reward_layers: 5
cont_layers: 5
value_layers: 5
actor_layers: 5
encoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
decoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
actor_dist: 'onehot'
imag_gradient: 'reinforce'
atari100k:
steps: 4e5
envs: 1

View File

@ -211,6 +211,11 @@ def make_env(config, logger, mode, train_eps, eval_eps):
task, mode if "train" in mode else "test", config.action_repeat
)
env = wrappers.OneHotAction(env)
elif suite == "crafter":
import envs.crafter as crafter
env = crafter.Crafter(task, config.size)
env = wrappers.OneHotAction(env)
else:
raise NotImplementedError(suite)
env = wrappers.TimeLimit(env, config.time_limit)

70
envs/crafter.py Normal file
View File

@ -0,0 +1,70 @@
import gym
import numpy as np
class Crafter:
def __init__(self, task, size=(64, 64), seed=None):
assert task in ("reward", "noreward")
import crafter
self._env = crafter.Env(size=size, reward=(task == "reward"), seed=seed)
self._achievements = crafter.constants.achievements.copy()
@property
def observation_space(self):
spaces = {
"image": gym.spaces.Box(
0, 255, self._env.observation_space.shape, dtype=np.uint8
),
"reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
"is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
"is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
"is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
"log_reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
}
spaces.update(
{
f"log_achievement_{k}": gym.spaces.Box(
-np.inf, np.inf, (1,), dtype=np.float32
)
for k in self._achievements
}
)
return gym.spaces.Dict(spaces)
@property
def action_space(self):
action_space = self._env.action_space
action_space.discrete = True
return action_space
def step(self, action):
image, reward, done, info = self._env.step(action)
reward = np.float32(reward)
log_achievements = {
f"log_achievement_{k}": info["achievements"][k] if info else 0
for k in self._achievements
}
obs = {
"image": image,
"reward": reward,
"is_first": False,
"is_last": done,
"is_terminal": info["discount"] == 0,
"log_reward": np.float32(info["reward"] if info else 0.0),
**log_achievements,
}
return obs, reward, done, info
def render(self):
return self._env.render()
def reset(self):
image = self._env.reset()
obs = {
"image": image,
"is_first": True,
"is_last": False,
"is_terminal": False,
}
return obs

View File

@ -179,18 +179,22 @@ class RewardObs:
@property
def observation_space(self):
spaces = self._env.observation_space.spaces
assert "reward" not in spaces
spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32)
if "reward" not in spaces:
spaces["reward"] = gym.spaces.Box(
-np.inf, np.inf, shape=(1,), dtype=np.float32
)
return gym.spaces.Dict(spaces)
def step(self, action):
obs, reward, done, info = self._env.step(action)
obs["reward"] = reward
if "reward" not in obs:
obs["reward"] = reward
return obs, reward, done, info
def reset(self):
obs = self._env.reset()
obs["reward"] = 0.0
if "reward" not in obs:
obs["reward"] = 0.0
return obs

View File

@ -347,7 +347,11 @@ class MultiEncoder(nn.Module):
):
super(MultiEncoder, self).__init__()
excluded = ("is_first", "is_last", "is_terminal", "reward")
shapes = {k: v for k, v in shapes.items() if k not in excluded}
shapes = {
k: v
for k, v in shapes.items()
if k not in excluded and not k.startswith("log_")
}
self.cnn_shapes = {
k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
}