Merge branch 'main' into memmaze
commit 2a8b2e84e0
@@ -19,6 +19,10 @@ Train the agent on Alien in Atari 100K:
 ```
 python3 dreamer.py --configs atari100k --task atari_alien --logdir ./logdir/atari_alien
 ```
+Train the agent on Crafter:
+```
+python3 dreamer.py --configs crafter --logdir ./logdir/crafter
+```
 Monitor results:
 ```
 tensorboard --logdir ./logdir
configs.yaml (38 changed lines)
@@ -120,22 +120,46 @@ defaults:
   disag_units: 400
   disag_action_cond: False

-dmc_vision:
-  steps: 1e6
-  train_ratio: 512
-  video_pred_log: true
-  encoder: {mlp_keys: '$^', cnn_keys: 'image'}
-  decoder: {mlp_keys: '$^', cnn_keys: 'image'}
-
 dmc_proprio:
   steps: 5e5
+  action_repeat: 2
+  envs: 4
   train_ratio: 512
   video_pred_log: false
   encoder: {mlp_keys: '.*', cnn_keys: '$^'}
   decoder: {mlp_keys: '.*', cnn_keys: '$^'}

+dmc_vision:
+  steps: 1e6
+  action_repeat: 2
+  envs: 4
+  train_ratio: 512
+  video_pred_log: true
+  encoder: {mlp_keys: '$^', cnn_keys: 'image'}
+  decoder: {mlp_keys: '$^', cnn_keys: 'image'}
+
+crafter:
+  task: crafter_reward
+  step: 1e6
+  action_repeat: 1
+  envs: 1
+  train_ratio: 512
+  video_pred_log: true
+  dyn_hidden: 1024
+  dyn_deter: 4096
+  units: 1024
+  reward_layers: 5
+  cont_layers: 5
+  value_layers: 5
+  actor_layers: 5
+  encoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
+  decoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
+  actor_dist: 'onehot'
+  imag_gradient: 'reinforce'
+
 atari100k:
   steps: 4e5
+  envs: 1
   action_repeat: 4
   eval_episode_num: 100
   stickey: False
dreamer.py (10 changed lines)
@@ -55,7 +55,9 @@ class Dreamer(nn.Module):
         self._task_behavior = models.ImagBehavior(
             config, self._wm, config.behavior_stop_grad
         )
-        if config.compile and os.name != 'nt':  # compilation is not supported on windows
+        if (
+            config.compile and os.name != "nt"
+        ):  # compilation is not supported on windows
             self._wm = torch.compile(self._wm)
             self._task_behavior = torch.compile(self._task_behavior)
         reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
@@ -156,7 +158,6 @@ class Dreamer(nn.Module):
         post, context, mets = self._wm._train(data)
         metrics.update(mets)
         start = post
-        # start['deter'] (16, 64, 512)
         reward = lambda f, s, a: self._wm.heads["reward"](
             self._wm.dynamics.get_feat(s)
         ).mode()
@@ -221,6 +222,11 @@ def make_env(config, logger, mode, train_eps, eval_eps):
         from envs.memorymaze import MemoryMaze
         env = MemoryMaze(env)
         env = wrappers.OneHotAction(env)
+    elif suite == "crafter":
+        import envs.crafter as crafter
+
+        env = crafter.Crafter(task, config.size)
+        env = wrappers.OneHotAction(env)
     else:
         raise NotImplementedError(suite)
     env = wrappers.TimeLimit(env, config.time_limit)
envs/crafter.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+import gym
+import numpy as np
+
+
+class Crafter:
+    def __init__(self, task, size=(64, 64), seed=None):
+        assert task in ("reward", "noreward")
+        import crafter
+
+        self._env = crafter.Env(size=size, reward=(task == "reward"), seed=seed)
+        self._achievements = crafter.constants.achievements.copy()
+
+    @property
+    def observation_space(self):
+        spaces = {
+            "image": gym.spaces.Box(
+                0, 255, self._env.observation_space.shape, dtype=np.uint8
+            ),
+            "reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+            "is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+            "is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+            "is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
+            "log_reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
+        }
+        spaces.update(
+            {
+                f"log_achievement_{k}": gym.spaces.Box(
+                    -np.inf, np.inf, (1,), dtype=np.float32
+                )
+                for k in self._achievements
+            }
+        )
+        return gym.spaces.Dict(spaces)
+
+    @property
+    def action_space(self):
+        action_space = self._env.action_space
+        action_space.discrete = True
+        return action_space
+
+    def step(self, action):
+        image, reward, done, info = self._env.step(action)
+        reward = np.float32(reward)
+        log_achievements = {
+            f"log_achievement_{k}": info["achievements"][k] if info else 0
+            for k in self._achievements
+        }
+        obs = {
+            "image": image,
+            "reward": reward,
+            "is_first": False,
+            "is_last": done,
+            "is_terminal": info["discount"] == 0,
+            "log_reward": np.float32(info["reward"] if info else 0.0),
+            **log_achievements,
+        }
+        return obs, reward, done, info
+
+    def render(self):
+        return self._env.render()
+
+    def reset(self):
+        image = self._env.reset()
+        obs = {
+            "image": image,
+            "is_first": True,
+            "is_last": False,
+            "is_terminal": False,
+        }
+        return obs
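For orientation, a minimal smoke test of the new wrapper (not part of the commit; it assumes the `crafter` package is installed and the repository root is on the Python path):

```
from envs.crafter import Crafter

env = Crafter("reward", size=(64, 64), seed=0)
obs = env.reset()
print(obs["image"].shape, obs["is_first"])  # expected: (64, 64, 3) True

# The wrapper takes plain integer actions; wrappers.OneHotAction converts the
# agent's one-hot vectors before they reach this class.
obs, reward, done, info = env.step(0)  # action 0 is Crafter's noop
print(reward, done, obs["log_reward"])
```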
@@ -179,18 +179,22 @@ class RewardObs:
     @property
     def observation_space(self):
         spaces = self._env.observation_space.spaces
-        assert "reward" not in spaces
-        spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32)
+        if "reward" not in spaces:
+            spaces["reward"] = gym.spaces.Box(
+                -np.inf, np.inf, shape=(1,), dtype=np.float32
+            )
         return gym.spaces.Dict(spaces)

     def step(self, action):
         obs, reward, done, info = self._env.step(action)
-        obs["reward"] = reward
+        if "reward" not in obs:
+            obs["reward"] = reward
         return obs, reward, done, info

     def reset(self):
         obs = self._env.reset()
-        obs["reward"] = 0.0
+        if "reward" not in obs:
+            obs["reward"] = 0.0
         return obs
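The assert and the unconditional assignments become `if "reward" not in ...` checks because the Crafter wrapper above already returns a "reward" entry in its observations; RewardObs now only backfills the key for environments that omit it. A tiny illustration with made-up values:

```
obs = {"image": None, "reward": 1.0}  # e.g. an observation from the Crafter wrapper
if "reward" not in obs:               # new behaviour: keep the env's own value
    obs["reward"] = 0.0
print(obs["reward"])                  # 1.0
```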
@@ -58,7 +58,9 @@ class Plan2Explore(nn.Module):
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=feat_size + config.num_actions if config.disag_action_cond else 0,  # pytorch version
+            inp_dim=feat_size + config.num_actions
+            if config.disag_action_cond
+            else 0,  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
@@ -93,7 +95,9 @@ class Plan2Explore(nn.Module):
         }[self._config.disag_target]
         inputs = context["feat"]
         if self._config.disag_action_cond:
-            inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1)
+            inputs = torch.concat(
+                [inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
+            )
         metrics.update(self._train_ensemble(inputs, target))
         metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics
@@ -399,13 +399,10 @@ class ImagBehavior(nn.Module):
         if self._config.future_entropy and self._config.actor_state_entropy() > 0:
             reward += self._config.actor_state_entropy() * state_ent
         value = self.value(imag_feat).mode()
-        # value(15, 960, ch)
-        # action(15, 960, ch)
-        # discount(15, 960, ch)
         target = tools.lambda_return(
-            reward[:-1],
+            reward[1:],
             value[:-1],
-            discount[:-1],
+            discount[1:],
             bootstrap=value[-1],
             lambda_=self._config.discount_lambda,
             axis=0,
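The index change pairs each value estimate with the reward and discount of the step that follows it. As a rough sketch of the recursion that `tools.lambda_return` is assumed to implement (my gloss, not code from the repo):

```
def lambda_return_sketch(reward, value, discount, bootstrap, lam):
    # In this local indexing, reward[t] and discount[t] belong to the transition
    # taken after state t, which is why the caller now passes reward[1:] and
    # discount[1:] alongside value[:-1].
    next_values = list(value[1:]) + [bootstrap]
    returns, last = [0.0] * len(reward), bootstrap
    for t in reversed(range(len(reward))):
        last = reward[t] + discount[t] * ((1 - lam) * next_values[t] + lam * last)
        returns[t] = last
    return returns

print(lambda_return_sketch([1.0, 1.0], [0.5, 0.6], [0.9, 0.9], bootstrap=0.7, lam=0.95))
# approximately [2.4207, 1.63]
```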
networks.py (11 changed lines)
@@ -215,7 +215,9 @@ class RSSM(nn.Module):
                     is_first,
                     is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
                 )
-                prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r
+                prev_state[key] = (
+                    val * (1.0 - is_first_r) + init_state[key] * is_first_r
+                )

         prior = self.img_step(prev_state, prev_action, None, sample)
         if self._shared:
@@ -345,7 +347,11 @@ class MultiEncoder(nn.Module):
     ):
         super(MultiEncoder, self).__init__()
         excluded = ("is_first", "is_last", "is_terminal", "reward")
-        shapes = {k: v for k, v in shapes.items() if k not in excluded}
+        shapes = {
+            k: v
+            for k, v in shapes.items()
+            if k not in excluded and not k.startswith("log_")
+        }
         self.cnn_shapes = {
             k: v for k, v in shapes.items() if len(v) == 3 and re.match(cnn_keys, k)
         }
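The added `startswith("log_")` filter keeps the logging-only observation keys introduced by the Crafter wrapper (log_reward, log_achievement_*) out of the encoder inputs. A small illustration with made-up shapes:

```
shapes = {"image": (64, 64, 3), "log_reward": (1,), "log_achievement_collect_wood": (1,)}
excluded = ("is_first", "is_last", "is_terminal", "reward")
kept = {
    k: v
    for k, v in shapes.items()
    if k not in excluded and not k.startswith("log_")
}
print(kept)  # {'image': (64, 64, 3)}
```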
@@ -547,6 +553,7 @@ class ConvDecoder(nn.Module):
         self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)

         self._linear_layer = nn.Linear(feat_size, self._embed_size)
+        self._linear_layer.apply(tools.weight_init)
         in_dim = self._embed_size // (minres**2)

         layers = []
tools.py (4 changed lines)
@@ -804,7 +804,9 @@ def weight_init(m):
         denoms = (in_num + out_num) / 2.0
         scale = 1.0 / denoms
         std = np.sqrt(scale) / 0.87962566103423978
-        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
+        nn.init.trunc_normal_(
+            m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
+        )
         if hasattr(m.bias, "data"):
             m.bias.data.fill_(0.0)
     elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
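The cutoff arguments are scaled by `std` because `nn.init.trunc_normal_` takes absolute bounds rather than multiples of the standard deviation, so `a=-2.0, b=2.0` only truncates near two standard deviations when std is close to 1. A quick check (assumes PyTorch; not part of the commit):

```
import torch
from torch import nn

std = 0.05
w = torch.empty(100000)

nn.init.trunc_normal_(w, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)
print(bool(w.abs().max() <= 2.0 * std))  # True: samples stay within two standard deviations

nn.init.trunc_normal_(w, mean=0.0, std=std, a=-2.0, b=2.0)
print(bool(w.abs().max() <= 2.0 * std))  # almost surely False: cutoffs at +/-2.0 barely bind here
```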