merged action head into MLP and modified configs
Parent: e0f2017e28 · Commit: e0487f8206

configs.yaml · 63 changed lines
@@ -47,26 +47,25 @@ defaults:
   dyn_temp_post: True
   grad_heads: ['decoder', 'reward', 'cont']
   units: 512
-  reward_layers: 2
-  cont_layers: 2
-  value_layers: 2
-  actor_layers: 2
   act: 'SiLU'
   norm: 'LayerNorm'
   encoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
   decoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse, outscale: 1.0}
-  value_head: 'symlog_disc'
-  reward_head: 'symlog_disc'
+  actor:
+    {layers: 2, dist: 'normal', entropy: 3e-4, unimix_ratio: 0.01, min_std: 0.1, max_std: 1.0, temp: 0.1, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 1.0}
+  critic:
+    {layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0}
+  reward_head:
+    {layers: 2, dist: 'symlog_disc', scale: 1.0, outscale: 0.0}
+  cont_head:
+    {layers: 2, scale: 1.0, outscale: 1.0}
   dyn_scale: 0.5
   rep_scale: 0.1
   kl_free: 1.0
-  cont_scale: 1.0
-  reward_scale: 1.0
   weight_decay: 0.0
   unimix_ratio: 0.01
-  action_unimix_ratio: 0.01
   initial: 'learned'

   # Training
@@ -77,15 +76,7 @@ defaults:
   model_lr: 1e-4
   opt_eps: 1e-8
   grad_clip: 1000
-  value_lr: 3e-5
-  actor_lr: 3e-5
-  ac_opt_eps: 1e-5
-  value_grad_clip: 100
-  actor_grad_clip: 100
   dataset_size: 1000000
-  slow_value_target: True
-  slow_target_update: 1
-  slow_target_fraction: 0.02
   opt: 'adam'

   # Behavior.
@@ -95,18 +86,10 @@ defaults:
   imag_gradient: 'dynamics'
   imag_gradient_mix: 0.0
   imag_sample: True
-  actor_dist: 'normal'
-  actor_entropy: 3e-4
-  actor_state_entropy: 0.0
-  actor_init_std: 1.0
-  actor_min_std: 0.1
-  actor_max_std: 1.0
-  actor_temp: 0.1
-  expl_amount: 0.0
+  expl_amount: 0
   eval_state_mean: False
   collect_dyn_sample: True
   behavior_stop_grad: True
-  value_decay: 0.0
   future_entropy: False

   # Exploration
@@ -150,13 +133,12 @@ crafter:
   dyn_hidden: 1024
   dyn_deter: 4096
   units: 1024
-  reward_layers: 5
-  cont_layers: 5
-  value_layers: 5
-  actor_layers: 5
   encoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
   decoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
-  actor_dist: 'onehot'
+  actor: {layers: 5, dist: 'onehot'}
+  value: {layers: 5}
+  reward_head: {layers: 5}
+  cont_head: {layers: 5}
   imag_gradient: 'reinforce'

 atari100k:
@@ -166,7 +148,7 @@ atari100k:
   train_ratio: 1024
   video_pred_log: true
   eval_episode_num: 100
-  actor_dist: 'onehot'
+  actor: {dist: 'onehot'}
   imag_gradient: 'reinforce'
   stickey: False
   lives: unused
@@ -189,13 +171,12 @@ minecraft:
   dyn_hidden: 1024
   dyn_deter: 4096
   units: 1024
-  reward_layers: 5
-  cont_layers: 5
-  value_layers: 5
-  actor_layers: 5
-  encoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|reward', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
+  encoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|obs_reward', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
   decoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
-  actor_dist: 'onehot'
+  actor: {layers: 5, dist: 'onehot'}
+  value: {layers: 5}
+  reward_head: {layers: 5}
+  cont_head: {layers: 5}
   imag_gradient: 'reinforce'
   break_speed: 100.0
   time_limit: 36000
@@ -203,7 +184,7 @@ minecraft:
 memorymaze:
   steps: 1e8
   action_repeat: 2
-  actor_dist: 'onehot'
+  actor: {dist: 'onehot'}
   imag_gradient: 'reinforce'
   task: 'memorymaze_9x9'

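Note: the flat per-head keys (`actor_*`, `value_*`, `reward_layers`, `cont_layers`, `slow_*`) are collapsed into one nested dict per head (`actor`, `critic`, `reward_head`, `cont_head`). A minimal sketch of the resulting access pattern; the `Config` wrapper below is a hypothetical stand-in for the repo's own config loader:

```python
# Hypothetical stand-in for the repo's config object; only the nested
# per-head layout introduced by this commit matters here.
class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)

config = Config(
    actor={"layers": 2, "dist": "normal", "lr": 3e-5, "eps": 1e-5, "grad_clip": 100.0},
    critic={"layers": 2, "dist": "symlog_disc", "slow_target": True},
    reward_head={"layers": 2, "dist": "symlog_disc", "scale": 1.0, "outscale": 0.0},
    cont_head={"layers": 2, "scale": 1.0, "outscale": 1.0},
)

# Old flat keys map onto nested lookups, e.g.:
#   config.actor_lr          -> config.actor["lr"]
#   config.value_head        -> config.critic["dist"]
#   config.slow_value_target -> config.critic["slow_target"]
assert config.actor["lr"] == 3e-5
assert config.critic["dist"] == "symlog_disc"
```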
@@ -110,7 +110,7 @@ class Dreamer(nn.Module):
         logprob = actor.log_prob(action)
         latent = {k: v.detach() for k, v in latent.items()}
         action = action.detach()
-        if self._config.actor_dist == "onehot_gumble":
+        if self._config.actor["dist"] == "onehot_gumble":
             action = torch.one_hot(
                 torch.argmax(action, dim=-1), self._config.num_actions
             )
@@ -123,7 +123,7 @@ class Dreamer(nn.Module):
         amount = self._config.expl_amount if training else self._config.eval_noise
         if amount == 0:
             return action
-        if "onehot" in self._config.actor_dist:
+        if "onehot" in self._config.actor["dist"]:
             probs = amount / self._config.num_actions + (1 - amount) * action
             return tools.OneHotDist(probs=probs).sample()
         else:
@@ -14,7 +14,7 @@ class Random(nn.Module):
         self._act_space = act_space

     def actor(self, feat):
-        if self._config.actor_dist == "onehot":
+        if self._config.actor["dist"] == "onehot":
             return tools.OneHotDist(
                 torch.zeros(self._config.num_actions)
                 .repeat(self._config.envs, 1)
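Note: the `_exploration` path kept above mixes a uniform distribution into the policy's one-hot action; a self-contained sketch with illustrative numbers:

```python
import torch

# Epsilon-style mixing over a one-hot action: `amount` of probability mass
# is spread uniformly across actions, the rest stays on the greedy action.
amount, num_actions = 0.1, 4
action = torch.tensor([0.0, 1.0, 0.0, 0.0])  # one-hot policy action

probs = amount / num_actions + (1 - amount) * action
print(probs)  # tensor([0.0250, 0.9250, 0.0250, 0.0250])
assert torch.isclose(probs.sum(), torch.tensor(1.0))
```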
models.py · 145 changed lines
@@ -67,39 +67,29 @@ class WorldModel(nn.Module):
         self.heads["decoder"] = networks.MultiDecoder(
             feat_size, shapes, **config.decoder
         )
-        if config.reward_head == "symlog_disc":
-            self.heads["reward"] = networks.MLP(
-                feat_size,  # pytorch version
-                (255,),
-                config.reward_layers,
-                config.units,
-                config.act,
-                config.norm,
-                dist=config.reward_head,
-                outscale=0.0,
-                device=config.device,
-            )
-        else:
-            self.heads["reward"] = networks.MLP(
-                feat_size,  # pytorch version
-                [],
-                config.reward_layers,
-                config.units,
-                config.act,
-                config.norm,
-                dist=config.reward_head,
-                outscale=0.0,
-                device=config.device,
-            )
+        self.heads["reward"] = networks.MLP(
+            feat_size,
+            (255,) if config.reward_head["dist"] == "symlog_disc" else (),
+            config.reward_head["layers"],
+            config.units,
+            config.act,
+            config.norm,
+            dist=config.reward_head["dist"],
+            outscale=config.reward_head["outscale"],
+            device=config.device,
+            name="Reward",
+        )
         self.heads["cont"] = networks.MLP(
-            feat_size,  # pytorch version
-            [],
-            config.cont_layers,
+            feat_size,
+            (),
+            config.cont_head["layers"],
             config.units,
             config.act,
             config.norm,
             dist="binary",
+            outscale=config.cont_head["outscale"],
             device=config.device,
+            name="Cont",
         )
         for name in config.grad_heads:
             assert name in self.heads, name
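Note: the reward head's output shape is now chosen inline: 255 logits for `symlog_disc` (discrete regression over symlog-spaced bins), a scalar otherwise. A hedged sketch of that switch using a plain `nn.Sequential` in place of `networks.MLP` (feature and unit sizes are illustrative):

```python
import torch
import torch.nn as nn

def head_out_dim(dist: str) -> int:
    # 'symlog_disc' heads emit 255 bin logits; other dists emit one scalar.
    return 255 if dist == "symlog_disc" else 1

feat_size, units = 1536, 512
for dist in ("symlog_disc", "symlog_mse"):
    head = nn.Sequential(
        nn.Linear(feat_size, units, bias=False),
        nn.LayerNorm(units, eps=1e-3),
        nn.SiLU(),
        nn.Linear(units, head_out_dim(dist)),
    )
    print(dist, head(torch.zeros(2, feat_size)).shape)
# symlog_disc torch.Size([2, 255]); symlog_mse torch.Size([2, 1])
```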
@@ -113,7 +103,14 @@ class WorldModel(nn.Module):
             opt=config.opt,
             use_amp=self._use_amp,
         )
-        self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)
+        print(
+            f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
+        )
+        self._scales = dict(
+            reward=config.reward_head["scale"],
+            cont=config.cont_head["scale"],
+            image=1.0,
+        )

     def _train(self, data):
         # action (batch_size, batch_length, act_dim)
@@ -134,6 +131,7 @@ class WorldModel(nn.Module):
         kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
             post, prior, kl_free, dyn_scale, rep_scale
         )
+        assert kl_loss.shape == embed.shape[:2], kl_loss.shape
         preds = {}
         for name, head in self.heads.items():
             grad_head = name in self._config.grad_heads
@@ -226,65 +224,60 @@ class ImagBehavior(nn.Module):
             feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
         else:
             feat_size = config.dyn_stoch + config.dyn_deter
-        self.actor = networks.ActionHead(
+        self.actor = networks.MLP(
             feat_size,
-            config.num_actions,
-            config.actor_layers,
+            (config.num_actions,),
+            config.actor["layers"],
             config.units,
             config.act,
             config.norm,
-            config.actor_dist,
-            config.actor_init_std,
-            config.actor_min_std,
-            config.actor_max_std,
-            config.actor_temp,
+            config.actor["dist"],
+            "learned",
+            config.actor["min_std"],
+            config.actor["max_std"],
+            config.actor["temp"],
+            unimix_ratio=config.actor["unimix_ratio"],
             outscale=1.0,
-            unimix_ratio=config.action_unimix_ratio,
+            name="Actor",
         )
-        if config.value_head == "symlog_disc":
-            self.value = networks.MLP(
-                feat_size,
-                (255,),
-                config.value_layers,
-                config.units,
-                config.act,
-                config.norm,
-                config.value_head,
-                outscale=0.0,
-                device=config.device,
-            )
-        else:
-            self.value = networks.MLP(
-                feat_size,
-                [],
-                config.value_layers,
-                config.units,
-                config.act,
-                config.norm,
-                config.value_head,
-                outscale=0.0,
-                device=config.device,
-            )
-        if config.slow_value_target:
+        self.value = networks.MLP(
+            feat_size,
+            (255,) if config.critic["dist"] == "symlog_disc" else (),
+            config.critic["layers"],
+            config.units,
+            config.act,
+            config.norm,
+            config.critic["dist"],
+            outscale=config.critic["outscale"],
+            device=config.device,
+            name="Value",
+        )
+        if config.critic["slow_target"]:
             self._slow_value = copy.deepcopy(self.value)
             self._updates = 0
         kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
         self._actor_opt = tools.Optimizer(
             "actor",
             self.actor.parameters(),
-            config.actor_lr,
-            config.ac_opt_eps,
-            config.actor_grad_clip,
+            config.actor["lr"],
+            config.actor["eps"],
+            config.actor["grad_clip"],
             **kw,
         )
+        print(
+            f"Optimizer actor_opt has {sum(param.numel() for param in self.actor.parameters())} variables."
+        )
         self._value_opt = tools.Optimizer(
             "value",
             self.value.parameters(),
-            config.value_lr,
-            config.ac_opt_eps,
-            config.value_grad_clip,
+            config.critic["lr"],
+            config.critic["eps"],
+            config.critic["grad_clip"],
             **kw,
         )
+        print(
+            f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
+        )
         if self._config.reward_EMA:
             self.reward_ema = RewardEMA(device=self._config.device)

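Note: the new `print` calls report how many trainable values each optimizer touches; the count is a plain sum of tensor element counts. A standalone sketch (the actor's sizes here are arbitrary):

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    # Same expression as in the print statements above.
    return sum(param.numel() for param in module.parameters())

actor = nn.Sequential(nn.Linear(1536, 512, bias=False), nn.SiLU(), nn.Linear(512, 6))
print(f"Optimizer actor_opt has {count_params(actor)} variables.")
# 1536*512 + 512*6 + 6 = 789510
```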
@@ -335,19 +328,15 @@ class ImagBehavior(nn.Module):
         # (time, batch, 1), (time, batch, 1) -> (time, batch)
         value_loss = -value.log_prob(target.detach())
         slow_target = self._slow_value(value_input[:-1].detach())
-        if self._config.slow_value_target:
-            value_loss = value_loss - value.log_prob(
-                slow_target.mode().detach()
-            )
-        if self._config.value_decay:
-            value_loss += self._config.value_decay * value.mode()
+        if self._config.critic["slow_target"]:
+            value_loss -= value.log_prob(slow_target.mode().detach())
         # (time, batch, 1), (time, batch, 1) -> (1,)
         value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])

         metrics.update(tools.tensorstats(value.mode(), "value"))
         metrics.update(tools.tensorstats(target, "target"))
         metrics.update(tools.tensorstats(reward, "imag_reward"))
-        if self._config.actor_dist in ["onehot"]:
+        if self._config.actor["dist"] in ["onehot"]:
             metrics.update(
                 tools.tensorstats(
                     torch.argmax(imag_action, dim=-1).float(), "imag_action"
@@ -466,9 +455,9 @@ class ImagBehavior(nn.Module):
         return actor_loss, metrics

     def _update_slow_target(self):
-        if self._config.slow_value_target:
-            if self._updates % self._config.slow_target_update == 0:
-                mix = self._config.slow_target_fraction
+        if self._config.critic["slow_target"]:
+            if self._updates % self._config.critic["slow_target_update"] == 0:
+                mix = self._config.critic["slow_target_fraction"]
                 for s, d in zip(self.value.parameters(), self._slow_value.parameters()):
                     d.data = mix * s.data + (1 - mix) * d.data
             self._updates += 1
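Note: `_update_slow_target` keeps an exponential moving average of the critic in weight space; every `slow_target_update` steps it blends `slow_target_fraction` of the online weights into the frozen copy. A standalone sketch of the update rule:

```python
import copy
import torch.nn as nn

value = nn.Linear(8, 1)            # stand-in for the critic MLP
slow_value = copy.deepcopy(value)  # frozen slow copy
mix = 0.02                         # critic["slow_target_fraction"]

# d <- mix * s + (1 - mix) * d, parameter tensor by parameter tensor.
for s, d in zip(value.parameters(), slow_value.parameters()):
    d.data = mix * s.data + (1 - mix) * d.data
```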
networks.py · 150 changed lines
@@ -632,9 +632,14 @@ class MLP(nn.Module):
         norm="LayerNorm",
         dist="normal",
         std=1.0,
+        min_std=0.1,
+        max_std=1.0,
+        temp=0.1,
+        unimix_ratio=0.01,
         outscale=1.0,
         symlog_inputs=False,
         device="cuda",
+        name="NoName",
     ):
         super(MLP, self).__init__()
         self._shape = (shape,) if isinstance(shape, int) else shape
@@ -647,15 +652,20 @@ class MLP(nn.Module):
         self._std = std
         self._symlog_inputs = symlog_inputs
         self._device = device
+        self._min_std = min_std
+        self._max_std = max_std
+        self._temp = temp
+        self._unimix_ratio = unimix_ratio

-        layers = []
+        self.layers = nn.Sequential()
         for index in range(self._layers):
-            layers.append(nn.Linear(inp_dim, units, bias=False))
-            layers.append(norm(units, eps=1e-03))
-            layers.append(act())
+            self.layers.add_module(
+                f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
+            )
+            self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
+            self.layers.add_module(f"{name}_act{index}", act())
             if index == 0:
                 inp_dim = units
-        self.layers = nn.Sequential(*layers)
         self.layers.apply(tools.weight_init)

         if isinstance(self._shape, dict):
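Note: registering submodules through `add_module` with an explicit `name` prefix makes parameter names self-describing in logs and state dicts. A sketch with a hypothetical "Actor" prefix:

```python
import torch.nn as nn

layers = nn.Sequential()
inp_dim, units = 1536, 512
for index in range(2):
    layers.add_module(f"Actor_linear{index}", nn.Linear(inp_dim, units, bias=False))
    layers.add_module(f"Actor_norm{index}", nn.LayerNorm(units, eps=1e-3))
    layers.add_module(f"Actor_act{index}", nn.SiLU())
    if index == 0:
        inp_dim = units

print([n for n, _ in layers.named_parameters()][:2])
# ['Actor_linear0.weight', 'Actor_norm0.weight'] instead of '0.weight', '1.weight'
```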
@@ -664,6 +674,7 @@ class MLP(nn.Module):
                 self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape))
             self.mean_layer.apply(tools.uniform_weight_init(outscale))
             if self._std == "learned":
+                assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
                 self.std_layer = nn.ModuleDict()
                 for name, shape in self._shape.items():
                     self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape))
@@ -672,6 +683,7 @@ class MLP(nn.Module):
             self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
             self.mean_layer.apply(tools.uniform_weight_init(outscale))
             if self._std == "learned":
+                assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
                 self.std_layer = nn.Linear(units, np.prod(self._shape))
                 self.std_layer.apply(tools.uniform_weight_init(outscale))

@@ -680,6 +692,7 @@ class MLP(nn.Module):
         if self._symlog_inputs:
             x = tools.symlog(x)
         out = self.layers(x)
+        # Used for encoder output
         if self._shape is None:
             return out
         if isinstance(self._shape, dict):
@@ -701,98 +714,9 @@ class MLP(nn.Module):
         return self.dist(self._dist, mean, std, self._shape)

     def dist(self, dist, mean, std, shape):
-        if dist == "normal":
-            return tools.ContDist(
-                torchd.independent.Independent(
-                    torchd.normal.Normal(mean, std), len(shape)
-                )
-            )
-        if dist == "huber":
-            return tools.ContDist(
-                torchd.independent.Independent(
-                    tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
-                )
-            )
-        if dist == "binary":
-            return tools.Bernoulli(
-                torchd.independent.Independent(
-                    torchd.bernoulli.Bernoulli(logits=mean), len(shape)
-                )
-            )
-        if dist == "symlog_disc":
-            return tools.DiscDist(logits=mean, device=self._device)
-        if dist == "symlog_mse":
-            return tools.SymlogDist(mean)
-        raise NotImplementedError(dist)
-
-
-class ActionHead(nn.Module):
-    def __init__(
-        self,
-        inp_dim,
-        size,
-        layers,
-        units,
-        act=nn.ELU,
-        norm=nn.LayerNorm,
-        dist="trunc_normal",
-        init_std=0.0,
-        min_std=0.1,
-        max_std=1.0,
-        temp=0.1,
-        outscale=1.0,
-        unimix_ratio=0.01,
-    ):
-        super(ActionHead, self).__init__()
-        self._size = size
-        self._layers = layers
-        self._units = units
-        self._dist = dist
-        act = getattr(torch.nn, act)
-        norm = getattr(torch.nn, norm)
-        self._min_std = min_std
-        self._max_std = max_std
-        self._init_std = init_std
-        self._unimix_ratio = unimix_ratio
-        self._temp = temp() if callable(temp) else temp
-
-        pre_layers = []
-        for index in range(self._layers):
-            pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
-            pre_layers.append(norm(self._units, eps=1e-03))
-            pre_layers.append(act())
-            if index == 0:
-                inp_dim = self._units
-        self._pre_layers = nn.Sequential(*pre_layers)
-        self._pre_layers.apply(tools.weight_init)
-
-        if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
-            self._dist_layer = nn.Linear(self._units, 2 * self._size)
-            self._dist_layer.apply(tools.uniform_weight_init(outscale))
-
-        elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
-            self._dist_layer = nn.Linear(self._units, self._size)
-            self._dist_layer.apply(tools.uniform_weight_init(outscale))
-
-    def forward(self, features, dtype=None):
-        x = features
-        x = self._pre_layers(x)
         if self._dist == "tanh_normal":
-            x = self._dist_layer(x)
-            mean, std = torch.split(x, 2, -1)
             mean = torch.tanh(mean)
-            std = F.softplus(std + self._init_std) + self._min_std
-            dist = torchd.normal.Normal(mean, std)
-            dist = torchd.transformed_distribution.TransformedDistribution(
-                dist, tools.TanhBijector()
-            )
-            dist = torchd.independent.Independent(dist, 1)
-            dist = tools.SampleDist(dist)
-        elif self._dist == "tanh_normal_5":
-            x = self._dist_layer(x)
-            mean, std = torch.split(x, 2, -1)
-            mean = 5 * torch.tanh(mean / 5)
-            std = F.softplus(std + 5) + 5
+            std = F.softplus(std) + self._min_std
             dist = torchd.normal.Normal(mean, std)
             dist = torchd.transformed_distribution.TransformedDistribution(
                 dist, tools.TanhBijector()
@@ -800,33 +724,41 @@ class ActionHead(nn.Module):
             dist = torchd.independent.Independent(dist, 1)
             dist = tools.SampleDist(dist)
         elif self._dist == "normal":
-            x = self._dist_layer(x)
-            mean, std = torch.split(x, [self._size] * 2, -1)
             std = (self._max_std - self._min_std) * torch.sigmoid(
                 std + 2.0
             ) + self._min_std
             dist = torchd.normal.Normal(torch.tanh(mean), std)
-            dist = tools.ContDist(torchd.independent.Independent(dist, 1))
-        elif self._dist == "normal_1":
-            mean = self._dist_layer(x)
-            dist = torchd.normal.Normal(mean, 1)
+            dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
+        elif self._dist == "normal_std_fixed":
+            dist = torchd.normal.Normal(mean, self._std)
             dist = tools.ContDist(torchd.independent.Independent(dist, 1))
         elif self._dist == "trunc_normal":
-            x = self._dist_layer(x)
-            mean, std = torch.split(x, [self._size] * 2, -1)
             mean = torch.tanh(mean)
             std = 2 * torch.sigmoid(std / 2) + self._min_std
             dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
             dist = tools.ContDist(torchd.independent.Independent(dist, 1))
         elif self._dist == "onehot":
-            x = self._dist_layer(x)
-            dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
+            dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
         elif self._dist == "onehot_gumble":
-            x = self._dist_layer(x)
-            temp = self._temp
-            dist = tools.ContDist(torchd.gumbel.Gumbel(x, 1 / temp))
+            dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
+        elif dist == "huber":
+            dist = tools.ContDist(
+                torchd.independent.Independent(
+                    tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
+                )
+            )
+        elif dist == "binary":
+            dist = tools.Bernoulli(
+                torchd.independent.Independent(
+                    torchd.bernoulli.Bernoulli(logits=mean), len(shape)
+                )
+            )
+        elif dist == "symlog_disc":
+            dist = tools.DiscDist(logits=mean, device=self._device)
+        elif dist == "symlog_mse":
+            dist = tools.SymlogDist(mean)
         else:
-            raise NotImplementedError(self._dist)
+            raise NotImplementedError(dist)
         return dist


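Note: the branches that previously lived in `ActionHead.forward` now key off `self._dist` inside `MLP.dist`, operating on the `mean`/`std` produced by `mean_layer`/`std_layer`. A sketch of the `'normal'` branch with plain torch distributions (the `tools.ContDist` wrapper and `absmax` clipping are omitted; sizes are illustrative):

```python
import torch
import torch.distributions as torchd

min_std, max_std = 0.1, 1.0
mean = torch.zeros(2, 6)  # stand-in for mean_layer output
std = torch.zeros(2, 6)   # stand-in for std_layer output

# Same squashing as the 'normal' branch: bounded std, tanh-squashed mean.
std = (max_std - min_std) * torch.sigmoid(std + 2.0) + min_std
dist = torchd.Independent(torchd.Normal(torch.tanh(mean), std), 1)
print(dist.sample().shape)  # torch.Size([2, 6])
```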