Merged ActionHead into MLP and regrouped configs into per-component dicts (actor, critic, reward_head, cont_head)

This commit is contained in:
NM512 2024-01-05 10:26:48 +09:00
parent e0f2017e28
commit e0487f8206
5 changed files with 133 additions and 231 deletions

View File

@ -47,26 +47,25 @@ defaults:
dyn_temp_post: True
grad_heads: ['decoder', 'reward', 'cont']
units: 512
reward_layers: 2
cont_layers: 2
value_layers: 2
actor_layers: 2
act: 'SiLU'
norm: 'LayerNorm'
encoder:
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
decoder:
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
value_head: 'symlog_disc'
reward_head: 'symlog_disc'
{mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse, outscale: 1.0}
actor:
{layers: 2, dist: 'normal', entropy: 3e-4, unimix_ratio: 0.01, min_std: 0.1, max_std: 1.0, temp: 0.1, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 1.0}
critic:
{layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0}
reward_head:
{layers: 2, dist: 'symlog_disc', scale: 1.0, outscale: 0.0}
cont_head:
{layers: 2, scale: 1.0, outscale: 1.0}
dyn_scale: 0.5
rep_scale: 0.1
kl_free: 1.0
cont_scale: 1.0
reward_scale: 1.0
weight_decay: 0.0
unimix_ratio: 0.01
action_unimix_ratio: 0.01
initial: 'learned'
# Training
@ -77,15 +76,7 @@ defaults:
model_lr: 1e-4
opt_eps: 1e-8
grad_clip: 1000
value_lr: 3e-5
actor_lr: 3e-5
ac_opt_eps: 1e-5
value_grad_clip: 100
actor_grad_clip: 100
dataset_size: 1000000
slow_value_target: True
slow_target_update: 1
slow_target_fraction: 0.02
opt: 'adam'
# Behavior
@ -95,18 +86,10 @@ defaults:
imag_gradient: 'dynamics'
imag_gradient_mix: 0.0
imag_sample: True
actor_dist: 'normal'
actor_entropy: 3e-4
actor_state_entropy: 0.0
actor_init_std: 1.0
actor_min_std: 0.1
actor_max_std: 1.0
actor_temp: 0.1
expl_amount: 0.0
expl_amount: 0
eval_state_mean: False
collect_dyn_sample: True
behavior_stop_grad: True
value_decay: 0.0
future_entropy: False
# Exploration
@ -150,13 +133,12 @@ crafter:
dyn_hidden: 1024
dyn_deter: 4096
units: 1024
reward_layers: 5
cont_layers: 5
value_layers: 5
actor_layers: 5
encoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
decoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
actor_dist: 'onehot'
actor: {layers: 5, dist: 'onehot'}
value: {layers: 5}
reward_head: {layers: 5}
cont_head: {layers: 5}
imag_gradient: 'reinforce'
atari100k:
@ -166,7 +148,7 @@ atari100k:
train_ratio: 1024
video_pred_log: true
eval_episode_num: 100
actor_dist: 'onehot'
actor: {dist: 'onehot'}
imag_gradient: 'reinforce'
stickey: False
lives: unused
@ -189,13 +171,12 @@ minecraft:
dyn_hidden: 1024
dyn_deter: 4096
units: 1024
reward_layers: 5
cont_layers: 5
value_layers: 5
actor_layers: 5
encoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|reward', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
encoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|obs_reward', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
decoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
actor_dist: 'onehot'
actor: {layers: 5, dist: 'onehot'}
value: {layers: 5}
reward_head: {layers: 5}
cont_head: {layers: 5}
imag_gradient: 'reinforce'
break_speed: 100.0
time_limit: 36000
@ -203,7 +184,7 @@ minecraft:
memorymaze:
steps: 1e8
action_repeat: 2
actor_dist: 'onehot'
actor: {dist: 'onehot'}
imag_gradient: 'reinforce'
task: 'memorymaze_9x9'

View File

@ -110,7 +110,7 @@ class Dreamer(nn.Module):
logprob = actor.log_prob(action)
latent = {k: v.detach() for k, v in latent.items()}
action = action.detach()
if self._config.actor_dist == "onehot_gumble":
if self._config.actor["dist"] == "onehot_gumble":
action = torch.one_hot(
torch.argmax(action, dim=-1), self._config.num_actions
)
@ -123,7 +123,7 @@ class Dreamer(nn.Module):
amount = self._config.expl_amount if training else self._config.eval_noise
if amount == 0:
return action
if "onehot" in self._config.actor_dist:
if "onehot" in self._config.actor["dist"]:
probs = amount / self._config.num_actions + (1 - amount) * action
return tools.OneHotDist(probs=probs).sample()
else:

View File

@ -14,7 +14,7 @@ class Random(nn.Module):
self._act_space = act_space
def actor(self, feat):
if self._config.actor_dist == "onehot":
if self._config.actor["dist"] == "onehot":
return tools.OneHotDist(
torch.zeros(self._config.num_actions)
.repeat(self._config.envs, 1)

145
models.py
View File

@ -67,39 +67,29 @@ class WorldModel(nn.Module):
self.heads["decoder"] = networks.MultiDecoder(
feat_size, shapes, **config.decoder
)
if config.reward_head == "symlog_disc":
self.heads["reward"] = networks.MLP(
feat_size, # pytorch version
(255,),
config.reward_layers,
config.units,
config.act,
config.norm,
dist=config.reward_head,
outscale=0.0,
device=config.device,
)
else:
self.heads["reward"] = networks.MLP(
feat_size, # pytorch version
[],
config.reward_layers,
config.units,
config.act,
config.norm,
dist=config.reward_head,
outscale=0.0,
device=config.device,
)
self.heads["reward"] = networks.MLP(
feat_size,
(255,) if config.reward_head["dist"] == "symlog_disc" else (),
config.reward_head["layers"],
config.units,
config.act,
config.norm,
dist=config.reward_head["dist"],
outscale=config.reward_head["outscale"],
device=config.device,
name="Reward",
)
self.heads["cont"] = networks.MLP(
feat_size, # pytorch version
[],
config.cont_layers,
feat_size,
(),
config.cont_head["layers"],
config.units,
config.act,
config.norm,
dist="binary",
outscale=config.cont_head["outscale"],
device=config.device,
name="Cont",
)
for name in config.grad_heads:
assert name in self.heads, name
@ -113,7 +103,14 @@ class WorldModel(nn.Module):
opt=config.opt,
use_amp=self._use_amp,
)
self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)
print(
f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
)
self._scales = dict(
reward=config.reward_head["scale"],
cont=config.cont_head["scale"],
image=1.0,
)
def _train(self, data):
# action (batch_size, batch_length, act_dim)
@ -134,6 +131,7 @@ class WorldModel(nn.Module):
kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
post, prior, kl_free, dyn_scale, rep_scale
)
assert kl_loss.shape == embed.shape[:2], kl_loss.shape
preds = {}
for name, head in self.heads.items():
grad_head = name in self._config.grad_heads
@ -226,65 +224,60 @@ class ImagBehavior(nn.Module):
feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
else:
feat_size = config.dyn_stoch + config.dyn_deter
self.actor = networks.ActionHead(
self.actor = networks.MLP(
feat_size,
config.num_actions,
config.actor_layers,
(config.num_actions,),
config.actor["layers"],
config.units,
config.act,
config.norm,
config.actor_dist,
config.actor_init_std,
config.actor_min_std,
config.actor_max_std,
config.actor_temp,
config.actor["dist"],
"learned",
config.actor["min_std"],
config.actor["max_std"],
config.actor["temp"],
unimix_ratio=config.actor["unimix_ratio"],
outscale=1.0,
unimix_ratio=config.action_unimix_ratio,
name="Actor",
)
if config.value_head == "symlog_disc":
self.value = networks.MLP(
feat_size,
(255,),
config.value_layers,
config.units,
config.act,
config.norm,
config.value_head,
outscale=0.0,
device=config.device,
)
else:
self.value = networks.MLP(
feat_size,
[],
config.value_layers,
config.units,
config.act,
config.norm,
config.value_head,
outscale=0.0,
device=config.device,
)
if config.slow_value_target:
self.value = networks.MLP(
feat_size,
(255,) if config.critic["dist"] == "symlog_disc" else (),
config.critic["layers"],
config.units,
config.act,
config.norm,
config.critic["dist"],
outscale=config.critic["outscale"],
device=config.device,
name="Value",
)
if config.critic["slow_target"]:
self._slow_value = copy.deepcopy(self.value)
self._updates = 0
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
self._actor_opt = tools.Optimizer(
"actor",
self.actor.parameters(),
config.actor_lr,
config.ac_opt_eps,
config.actor_grad_clip,
config.actor["lr"],
config.actor["eps"],
config.actor["grad_clip"],
**kw,
)
print(
f"Optimizer actor_opt has {sum(param.numel() for param in self.actor.parameters())} variables."
)
self._value_opt = tools.Optimizer(
"value",
self.value.parameters(),
config.value_lr,
config.ac_opt_eps,
config.value_grad_clip,
config.critic["lr"],
config.critic["eps"],
config.critic["grad_clip"],
**kw,
)
print(
f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
)
if self._config.reward_EMA:
self.reward_ema = RewardEMA(device=self._config.device)
@ -335,19 +328,15 @@ class ImagBehavior(nn.Module):
# (time, batch, 1), (time, batch, 1) -> (time, batch)
value_loss = -value.log_prob(target.detach())
slow_target = self._slow_value(value_input[:-1].detach())
if self._config.slow_value_target:
value_loss = value_loss - value.log_prob(
slow_target.mode().detach()
)
if self._config.value_decay:
value_loss += self._config.value_decay * value.mode()
if self._config.critic["slow_target"]:
value_loss -= value.log_prob(slow_target.mode().detach())
# (time, batch, 1), (time, batch, 1) -> (1,)
value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
metrics.update(tools.tensorstats(value.mode(), "value"))
metrics.update(tools.tensorstats(target, "target"))
metrics.update(tools.tensorstats(reward, "imag_reward"))
if self._config.actor_dist in ["onehot"]:
if self._config.actor["dist"] in ["onehot"]:
metrics.update(
tools.tensorstats(
torch.argmax(imag_action, dim=-1).float(), "imag_action"
@ -466,9 +455,9 @@ class ImagBehavior(nn.Module):
return actor_loss, metrics
def _update_slow_target(self):
if self._config.slow_value_target:
if self._updates % self._config.slow_target_update == 0:
mix = self._config.slow_target_fraction
if self._config.critic["slow_target"]:
if self._updates % self._config.critic["slow_target_update"] == 0:
mix = self._config.critic["slow_target_fraction"]
for s, d in zip(self.value.parameters(), self._slow_value.parameters()):
d.data = mix * s.data + (1 - mix) * d.data
self._updates += 1

View File

@ -632,9 +632,14 @@ class MLP(nn.Module):
norm="LayerNorm",
dist="normal",
std=1.0,
min_std=0.1,
max_std=1.0,
temp=0.1,
unimix_ratio=0.01,
outscale=1.0,
symlog_inputs=False,
device="cuda",
name="NoName",
):
super(MLP, self).__init__()
self._shape = (shape,) if isinstance(shape, int) else shape
@ -647,15 +652,20 @@ class MLP(nn.Module):
self._std = std
self._symlog_inputs = symlog_inputs
self._device = device
self._min_std = min_std
self._max_std = max_std
self._temp = temp
self._unimix_ratio = unimix_ratio
layers = []
self.layers = nn.Sequential()
for index in range(self._layers):
layers.append(nn.Linear(inp_dim, units, bias=False))
layers.append(norm(units, eps=1e-03))
layers.append(act())
self.layers.add_module(
f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
)
self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
self.layers.add_module(f"{name}_act{index}", act())
if index == 0:
inp_dim = units
self.layers = nn.Sequential(*layers)
self.layers.apply(tools.weight_init)
if isinstance(self._shape, dict):
@ -664,6 +674,7 @@ class MLP(nn.Module):
self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape))
self.mean_layer.apply(tools.uniform_weight_init(outscale))
if self._std == "learned":
assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
self.std_layer = nn.ModuleDict()
for name, shape in self._shape.items():
self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape))
@ -672,6 +683,7 @@ class MLP(nn.Module):
self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
self.mean_layer.apply(tools.uniform_weight_init(outscale))
if self._std == "learned":
assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
self.std_layer = nn.Linear(units, np.prod(self._shape))
self.std_layer.apply(tools.uniform_weight_init(outscale))
@ -680,6 +692,7 @@ class MLP(nn.Module):
if self._symlog_inputs:
x = tools.symlog(x)
out = self.layers(x)
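# A shape of None means this MLP is used as a feature extractor (e.g. the encoder): return the raw hidden features instead of a distribution.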
if self._shape is None:
return out
if isinstance(self._shape, dict):
@ -701,98 +714,9 @@ class MLP(nn.Module):
return self.dist(self._dist, mean, std, self._shape)
def dist(self, dist, mean, std, shape):
if dist == "normal":
return tools.ContDist(
torchd.independent.Independent(
torchd.normal.Normal(mean, std), len(shape)
)
)
if dist == "huber":
return tools.ContDist(
torchd.independent.Independent(
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
)
)
if dist == "binary":
return tools.Bernoulli(
torchd.independent.Independent(
torchd.bernoulli.Bernoulli(logits=mean), len(shape)
)
)
if dist == "symlog_disc":
return tools.DiscDist(logits=mean, device=self._device)
if dist == "symlog_mse":
return tools.SymlogDist(mean)
raise NotImplementedError(dist)
class ActionHead(nn.Module):
def __init__(
self,
inp_dim,
size,
layers,
units,
act=nn.ELU,
norm=nn.LayerNorm,
dist="trunc_normal",
init_std=0.0,
min_std=0.1,
max_std=1.0,
temp=0.1,
outscale=1.0,
unimix_ratio=0.01,
):
super(ActionHead, self).__init__()
self._size = size
self._layers = layers
self._units = units
self._dist = dist
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._min_std = min_std
self._max_std = max_std
self._init_std = init_std
self._unimix_ratio = unimix_ratio
self._temp = temp() if callable(temp) else temp
pre_layers = []
for index in range(self._layers):
pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
pre_layers.append(norm(self._units, eps=1e-03))
pre_layers.append(act())
if index == 0:
inp_dim = self._units
self._pre_layers = nn.Sequential(*pre_layers)
self._pre_layers.apply(tools.weight_init)
if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
self._dist_layer = nn.Linear(self._units, 2 * self._size)
self._dist_layer.apply(tools.uniform_weight_init(outscale))
elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
self._dist_layer = nn.Linear(self._units, self._size)
self._dist_layer.apply(tools.uniform_weight_init(outscale))
def forward(self, features, dtype=None):
x = features
x = self._pre_layers(x)
if self._dist == "tanh_normal":
x = self._dist_layer(x)
mean, std = torch.split(x, 2, -1)
mean = torch.tanh(mean)
std = F.softplus(std + self._init_std) + self._min_std
dist = torchd.normal.Normal(mean, std)
dist = torchd.transformed_distribution.TransformedDistribution(
dist, tools.TanhBijector()
)
dist = torchd.independent.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == "tanh_normal_5":
x = self._dist_layer(x)
mean, std = torch.split(x, 2, -1)
mean = 5 * torch.tanh(mean / 5)
std = F.softplus(std + 5) + 5
std = F.softplus(std) + self._min_std
dist = torchd.normal.Normal(mean, std)
dist = torchd.transformed_distribution.TransformedDistribution(
dist, tools.TanhBijector()
@ -800,33 +724,41 @@ class ActionHead(nn.Module):
dist = torchd.independent.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == "normal":
x = self._dist_layer(x)
mean, std = torch.split(x, [self._size] * 2, -1)
std = (self._max_std - self._min_std) * torch.sigmoid(
std + 2.0
) + self._min_std
dist = torchd.normal.Normal(torch.tanh(mean), std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "normal_1":
mean = self._dist_layer(x)
dist = torchd.normal.Normal(mean, 1)
dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
elif self._dist == "normal_std_fixed":
dist = torchd.normal.Normal(mean, self._std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "trunc_normal":
x = self._dist_layer(x)
mean, std = torch.split(x, [self._size] * 2, -1)
mean = torch.tanh(mean)
std = 2 * torch.sigmoid(std / 2) + self._min_std
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "onehot":
x = self._dist_layer(x)
dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
elif self._dist == "onehot_gumble":
x = self._dist_layer(x)
temp = self._temp
dist = tools.ContDist(torchd.gumbel.Gumbel(x, 1 / temp))
dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
elif dist == "huber":
dist = tools.ContDist(
torchd.independent.Independent(
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
)
)
elif dist == "binary":
dist = tools.Bernoulli(
torchd.independent.Independent(
torchd.bernoulli.Bernoulli(logits=mean), len(shape)
)
)
elif dist == "symlog_disc":
dist = tools.DiscDist(logits=mean, device=self._device)
elif dist == "symlog_mse":
dist = tools.SymlogDist(mean)
else:
raise NotImplementedError(self._dist)
raise NotImplementedError(dist)
return dist