merged action head into MLP and modified configs
parent e0f2017e28
commit e0487f8206

configs.yaml (63)
@@ -47,26 +47,25 @@ defaults:
   dyn_temp_post: True
   grad_heads: ['decoder', 'reward', 'cont']
   units: 512
-  reward_layers: 2
-  cont_layers: 2
-  value_layers: 2
-  actor_layers: 2
   act: 'SiLU'
   norm: 'LayerNorm'
   encoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, symlog_inputs: True}
   decoder:
-    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
-  value_head: 'symlog_disc'
-  reward_head: 'symlog_disc'
+    {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: True, cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse, outscale: 1.0}
+  actor:
+    {layers: 2, dist: 'normal', entropy: 3e-4, unimix_ratio: 0.01, min_std: 0.1, max_std: 1.0, temp: 0.1, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 1.0}
+  critic:
+    {layers: 2, dist: 'symlog_disc', slow_target: True, slow_target_update: 1, slow_target_fraction: 0.02, lr: 3e-5, eps: 1e-5, grad_clip: 100.0, outscale: 0.0}
+  reward_head:
+    {layers: 2, dist: 'symlog_disc', scale: 1.0, outscale: 0.0}
+  cont_head:
+    {layers: 2, scale: 1.0, outscale: 1.0}
   dyn_scale: 0.5
   rep_scale: 0.1
   kl_free: 1.0
-  cont_scale: 1.0
-  reward_scale: 1.0
   weight_decay: 0.0
   unimix_ratio: 0.01
-  action_unimix_ratio: 0.01
   initial: 'learned'

   # Training
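For orientation: the flat per-head keys deleted above (reward_layers, value_layers, actor_lr, action_unimix_ratio, ...) are regrouped into one dictionary per head or per optimizer. A small illustrative sketch of that mapping, written as made-up Python dict literals rather than the repo's YAML loader:

```python
# Illustrative only: how the old flat actor settings map onto the new grouped
# `actor` dictionary. Key names follow the diff; the dict literals are made up.
OLD_DEFAULTS = {
    "actor_layers": 2, "actor_dist": "normal", "actor_entropy": 3e-4,
    "actor_min_std": 0.1, "actor_max_std": 1.0, "actor_temp": 0.1,
    "actor_lr": 3e-5, "ac_opt_eps": 1e-5, "actor_grad_clip": 100,
    "action_unimix_ratio": 0.01,
}

NEW_DEFAULTS = {
    "actor": {
        "layers": 2, "dist": "normal", "entropy": 3e-4, "unimix_ratio": 0.01,
        "min_std": 0.1, "max_std": 1.0, "temp": 0.1,
        "lr": 3e-5, "eps": 1e-5, "grad_clip": 100.0, "outscale": 1.0,
    },
}

# Call sites change from attribute access to dictionary lookup, e.g.
#   config.actor_dist -> config.actor["dist"]
#   config.actor_lr   -> config.actor["lr"]
assert NEW_DEFAULTS["actor"]["dist"] == OLD_DEFAULTS["actor_dist"]
assert NEW_DEFAULTS["actor"]["lr"] == OLD_DEFAULTS["actor_lr"]
```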
@@ -77,15 +76,7 @@ defaults:
   model_lr: 1e-4
   opt_eps: 1e-8
   grad_clip: 1000
-  value_lr: 3e-5
-  actor_lr: 3e-5
-  ac_opt_eps: 1e-5
-  value_grad_clip: 100
-  actor_grad_clip: 100
   dataset_size: 1000000
-  slow_value_target: True
-  slow_target_update: 1
-  slow_target_fraction: 0.02
   opt: 'adam'

   # Behavior.
@@ -95,18 +86,10 @@ defaults:
   imag_gradient: 'dynamics'
   imag_gradient_mix: 0.0
   imag_sample: True
-  actor_dist: 'normal'
-  actor_entropy: 3e-4
-  actor_state_entropy: 0.0
-  actor_init_std: 1.0
-  actor_min_std: 0.1
-  actor_max_std: 1.0
-  actor_temp: 0.1
-  expl_amount: 0.0
+  expl_amount: 0
   eval_state_mean: False
   collect_dyn_sample: True
   behavior_stop_grad: True
-  value_decay: 0.0
   future_entropy: False

   # Exploration
@@ -150,13 +133,12 @@ crafter:
   dyn_hidden: 1024
   dyn_deter: 4096
   units: 1024
-  reward_layers: 5
-  cont_layers: 5
-  value_layers: 5
-  actor_layers: 5
   encoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
   decoder: {mlp_keys: '$^', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
-  actor_dist: 'onehot'
+  actor: {layers: 5, dist: 'onehot'}
+  value: {layers: 5}
+  reward_head: {layers: 5}
+  cont_head: {layers: 5}
   imag_gradient: 'reinforce'

 atari100k:
@@ -166,7 +148,7 @@ atari100k:
   train_ratio: 1024
   video_pred_log: true
   eval_episode_num: 100
-  actor_dist: 'onehot'
+  actor: {dist: 'onehot'}
   imag_gradient: 'reinforce'
   stickey: False
   lives: unused
@@ -189,13 +171,12 @@ minecraft:
   dyn_hidden: 1024
   dyn_deter: 4096
   units: 1024
-  reward_layers: 5
-  cont_layers: 5
-  value_layers: 5
-  actor_layers: 5
-  encoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|reward', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
+  encoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath|obs_reward', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
   decoder: {mlp_keys: 'inventory|inventory_max|equipped|health|hunger|breath', cnn_keys: 'image', cnn_depth: 96, mlp_layers: 5, mlp_units: 1024}
-  actor_dist: 'onehot'
+  actor: {layers: 5, dist: 'onehot'}
+  value: {layers: 5}
+  reward_head: {layers: 5}
+  cont_head: {layers: 5}
   imag_gradient: 'reinforce'
   break_speed: 100.0
   time_limit: 36000

@@ -203,7 +184,7 @@ minecraft:
 memorymaze:
   steps: 1e8
   action_repeat: 2
-  actor_dist: 'onehot'
+  actor: {dist: 'onehot'}
   imag_gradient: 'reinforce'
   task: 'memorymaze_9x9'
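The per-task blocks (crafter, atari100k, minecraft, memorymaze) now override whole sub-dicts such as actor: {...}. A sketch of the recursive merge this implies; the repo's actual config loader may resolve overrides differently, and the helper below is hypothetical:

```python
# Hypothetical sketch of merging a task block over the defaults.
defaults = {"actor": {"layers": 2, "dist": "normal"}, "units": 512}
crafter = {"actor": {"layers": 5, "dist": "onehot"}, "units": 1024}

def merge(base: dict, override: dict) -> dict:
    # Recursively overlay `override` onto `base`, descending into nested dicts.
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            out[key] = merge(base[key], value)
        else:
            out[key] = value
    return out

print(merge(defaults, crafter))
# {'actor': {'layers': 5, 'dist': 'onehot'}, 'units': 1024}
```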
@@ -110,7 +110,7 @@ class Dreamer(nn.Module):
 logprob = actor.log_prob(action)
 latent = {k: v.detach() for k, v in latent.items()}
 action = action.detach()
-if self._config.actor_dist == "onehot_gumble":
+if self._config.actor["dist"] == "onehot_gumble":
     action = torch.one_hot(
         torch.argmax(action, dim=-1), self._config.num_actions
     )
@@ -123,7 +123,7 @@ class Dreamer(nn.Module):
 amount = self._config.expl_amount if training else self._config.eval_noise
 if amount == 0:
     return action
-if "onehot" in self._config.actor_dist:
+if "onehot" in self._config.actor["dist"]:
     probs = amount / self._config.num_actions + (1 - amount) * action
     return tools.OneHotDist(probs=probs).sample()
 else:
@@ -14,7 +14,7 @@ class Random(nn.Module):
 self._act_space = act_space

 def actor(self, feat):
-    if self._config.actor_dist == "onehot":
+    if self._config.actor["dist"] == "onehot":
         return tools.OneHotDist(
             torch.zeros(self._config.num_actions)
             .repeat(self._config.envs, 1)
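All three call sites above only swap config.actor_dist for config.actor["dist"]; the exploration logic itself is unchanged. For reference, a standalone sketch of the uniform-mixing step used for one-hot actions (tensor values and the helper name are made up):

```python
import torch

def mix_onehot_exploration(action: torch.Tensor, amount: float, num_actions: int) -> torch.Tensor:
    # Blend the policy's one-hot/probability vector with a uniform distribution,
    # as in the `probs = amount / num_actions + (1 - amount) * action` line above.
    return amount / num_actions + (1 - amount) * action

probs = mix_onehot_exploration(torch.tensor([0.0, 1.0, 0.0, 0.0]), amount=0.1, num_actions=4)
print(probs)  # tensor([0.0250, 0.9250, 0.0250, 0.0250])
assert torch.isclose(probs.sum(), torch.tensor(1.0))
```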
models.py (145)
@@ -67,39 +67,29 @@ class WorldModel(nn.Module):
 self.heads["decoder"] = networks.MultiDecoder(
     feat_size, shapes, **config.decoder
 )
-if config.reward_head == "symlog_disc":
-    self.heads["reward"] = networks.MLP(
-        feat_size,  # pytorch version
-        (255,),
-        config.reward_layers,
-        config.units,
-        config.act,
-        config.norm,
-        dist=config.reward_head,
-        outscale=0.0,
-        device=config.device,
-    )
-else:
-    self.heads["reward"] = networks.MLP(
-        feat_size,  # pytorch version
-        [],
-        config.reward_layers,
-        config.units,
-        config.act,
-        config.norm,
-        dist=config.reward_head,
-        outscale=0.0,
-        device=config.device,
-    )
+self.heads["reward"] = networks.MLP(
+    feat_size,
+    (255,) if config.reward_head["dist"] == "symlog_disc" else (),
+    config.reward_head["layers"],
+    config.units,
+    config.act,
+    config.norm,
+    dist=config.reward_head["dist"],
+    outscale=config.reward_head["outscale"],
+    device=config.device,
+    name="Reward",
+)
 self.heads["cont"] = networks.MLP(
-    feat_size,  # pytorch version
-    [],
-    config.cont_layers,
+    feat_size,
+    (),
+    config.cont_head["layers"],
     config.units,
     config.act,
     config.norm,
     dist="binary",
+    outscale=config.cont_head["outscale"],
     device=config.device,
+    name="Cont",
 )
 for name in config.grad_heads:
     assert name in self.heads, name
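Both heads are now built through the same networks.MLP, with the output shape chosen from the head's dist setting. The (255,) shape exists because symlog_disc predicts a categorical over 255 fixed bins in symlog space rather than a single regressed scalar; the repo's tools.DiscDist implements the actual two-hot loss. A minimal sketch of the symlog transform, with an assumed bin range:

```python
import torch

def symlog(x: torch.Tensor) -> torch.Tensor:
    # Squash large magnitudes while keeping the sign and small values roughly linear.
    return torch.sign(x) * torch.log(torch.abs(x) + 1.0)

def symexp(x: torch.Tensor) -> torch.Tensor:
    # Inverse of symlog.
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)

bins = torch.linspace(-20.0, 20.0, 255)      # assumed bin layout, 255 bins as in the head shape
target = symlog(torch.tensor(42.0))          # scalar reward mapped into symlog space
idx = torch.argmin(torch.abs(bins - target)) # bin the two-hot target would straddle
print(idx, symexp(target))                   # symexp recovers ~42.0
```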
@@ -113,7 +103,14 @@ class WorldModel(nn.Module):
     opt=config.opt,
     use_amp=self._use_amp,
 )
-self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)
+print(
+    f"Optimizer model_opt has {sum(param.numel() for param in self.parameters())} variables."
+)
+self._scales = dict(
+    reward=config.reward_head["scale"],
+    cont=config.cont_head["scale"],
+    image=1.0,
+)

 def _train(self, data):
     # action (batch_size, batch_length, act_dim)
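The added print line only reports the parameter count handed to the optimizer; the loss scales now come from the grouped head dicts. A self-contained sketch of the same counting pattern on a toy model:

```python
import torch.nn as nn

# Toy stand-in for the world model; sizes are arbitrary.
model = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 1))
n_params = sum(param.numel() for param in model.parameters())
print(f"Optimizer model_opt has {n_params} variables.")  # 8*16+16 + 16*1+1 = 161
```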
@@ -134,6 +131,7 @@ class WorldModel(nn.Module):
 kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
     post, prior, kl_free, dyn_scale, rep_scale
 )
+assert kl_loss.shape == embed.shape[:2], kl_loss.shape
 preds = {}
 for name, head in self.heads.items():
     grad_head = name in self._config.grad_heads
@@ -226,65 +224,60 @@ class ImagBehavior(nn.Module):
     feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter
 else:
     feat_size = config.dyn_stoch + config.dyn_deter
-self.actor = networks.ActionHead(
+self.actor = networks.MLP(
     feat_size,
-    config.num_actions,
-    config.actor_layers,
+    (config.num_actions,),
+    config.actor["layers"],
     config.units,
     config.act,
     config.norm,
-    config.actor_dist,
-    config.actor_init_std,
-    config.actor_min_std,
-    config.actor_max_std,
-    config.actor_temp,
+    config.actor["dist"],
+    "learned",
+    config.actor["min_std"],
+    config.actor["max_std"],
+    config.actor["temp"],
+    unimix_ratio=config.actor["unimix_ratio"],
     outscale=1.0,
-    unimix_ratio=config.action_unimix_ratio,
+    name="Actor",
 )
-if config.value_head == "symlog_disc":
-    self.value = networks.MLP(
-        feat_size,
-        (255,),
-        config.value_layers,
-        config.units,
-        config.act,
-        config.norm,
-        config.value_head,
-        outscale=0.0,
-        device=config.device,
-    )
-else:
-    self.value = networks.MLP(
-        feat_size,
-        [],
-        config.value_layers,
-        config.units,
-        config.act,
-        config.norm,
-        config.value_head,
-        outscale=0.0,
-        device=config.device,
-    )
-if config.slow_value_target:
+self.value = networks.MLP(
+    feat_size,
+    (255,) if config.critic["dist"] == "symlog_disc" else (),
+    config.critic["layers"],
+    config.units,
+    config.act,
+    config.norm,
+    config.critic["dist"],
+    outscale=config.critic["outscale"],
+    device=config.device,
+    name="Value",
+)
+if config.critic["slow_target"]:
     self._slow_value = copy.deepcopy(self.value)
     self._updates = 0
 kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
 self._actor_opt = tools.Optimizer(
     "actor",
     self.actor.parameters(),
-    config.actor_lr,
-    config.ac_opt_eps,
-    config.actor_grad_clip,
+    config.actor["lr"],
+    config.actor["eps"],
+    config.actor["grad_clip"],
     **kw,
 )
+print(
+    f"Optimizer actor_opt has {sum(param.numel() for param in self.actor.parameters())} variables."
+)
 self._value_opt = tools.Optimizer(
     "value",
     self.value.parameters(),
-    config.value_lr,
-    config.ac_opt_eps,
-    config.value_grad_clip,
+    config.critic["lr"],
+    config.critic["eps"],
+    config.critic["grad_clip"],
    **kw,
 )
+print(
+    f"Optimizer value_opt has {sum(param.numel() for param in self.value.parameters())} variables."
+)
 if self._config.reward_EMA:
     self.reward_ema = RewardEMA(device=self._config.device)
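The actor and critic are now plain networks.MLP heads configured entirely from config.actor and config.critic. A small sketch of how the critic dict picks the value head's output shape; critic_cfg below is a made-up stand-in for config.critic:

```python
# Stand-in for config.critic, using the default values from the diff above.
critic_cfg = {
    "layers": 2, "dist": "symlog_disc", "slow_target": True,
    "slow_target_update": 1, "slow_target_fraction": 0.02,
    "lr": 3e-5, "eps": 1e-5, "grad_clip": 100.0, "outscale": 0.0,
}

# symlog_disc -> categorical head over 255 bins; anything else -> scalar head.
out_shape = (255,) if critic_cfg["dist"] == "symlog_disc" else ()
print(out_shape)  # (255,)
```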
@@ -335,19 +328,15 @@ class ImagBehavior(nn.Module):
 # (time, batch, 1), (time, batch, 1) -> (time, batch)
 value_loss = -value.log_prob(target.detach())
 slow_target = self._slow_value(value_input[:-1].detach())
-if self._config.slow_value_target:
-    value_loss = value_loss - value.log_prob(
-        slow_target.mode().detach()
-    )
-if self._config.value_decay:
-    value_loss += self._config.value_decay * value.mode()
+if self._config.critic["slow_target"]:
+    value_loss -= value.log_prob(slow_target.mode().detach())
 # (time, batch, 1), (time, batch, 1) -> (1,)
 value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])

 metrics.update(tools.tensorstats(value.mode(), "value"))
 metrics.update(tools.tensorstats(target, "target"))
 metrics.update(tools.tensorstats(reward, "imag_reward"))
-if self._config.actor_dist in ["onehot"]:
+if self._config.actor["dist"] in ["onehot"]:
     metrics.update(
         tools.tensorstats(
             torch.argmax(imag_action, dim=-1).float(), "imag_action"
@@ -466,9 +455,9 @@ class ImagBehavior(nn.Module):
 return actor_loss, metrics

 def _update_slow_target(self):
-    if self._config.slow_value_target:
-        if self._updates % self._config.slow_target_update == 0:
-            mix = self._config.slow_target_fraction
+    if self._config.critic["slow_target"]:
+        if self._updates % self._config.critic["slow_target_update"] == 0:
+            mix = self._config.critic["slow_target_fraction"]
         for s, d in zip(self.value.parameters(), self._slow_value.parameters()):
             d.data = mix * s.data + (1 - mix) * d.data
         self._updates += 1
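The slow-critic update kept above is a Polyak/EMA step that nudges a frozen copy toward the live critic. A standalone sketch with an assumed mixing fraction (config.critic["slow_target_fraction"] in the diff):

```python
import copy
import torch.nn as nn

value = nn.Linear(4, 1)                 # toy stand-in for the critic
slow_value = copy.deepcopy(value)       # frozen copy used as regression target
mix = 0.02                              # slow_target_fraction

for s, d in zip(value.parameters(), slow_value.parameters()):
    # Move the slow copy a small step toward the live critic's weights.
    d.data = mix * s.data + (1 - mix) * d.data
```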
networks.py (150)
@@ -632,9 +632,14 @@ class MLP(nn.Module):
 norm="LayerNorm",
 dist="normal",
 std=1.0,
+min_std=0.1,
+max_std=1.0,
+temp=0.1,
+unimix_ratio=0.01,
+outscale=1.0,
 symlog_inputs=False,
 device="cuda",
 name="NoName",
 ):
 super(MLP, self).__init__()
 self._shape = (shape,) if isinstance(shape, int) else shape
@@ -647,15 +652,20 @@ class MLP(nn.Module):
 self._std = std
 self._symlog_inputs = symlog_inputs
 self._device = device
+self._min_std = min_std
+self._max_std = max_std
+self._temp = temp
+self._unimix_ratio = unimix_ratio

-layers = []
+self.layers = nn.Sequential()
 for index in range(self._layers):
-    layers.append(nn.Linear(inp_dim, units, bias=False))
-    layers.append(norm(units, eps=1e-03))
-    layers.append(act())
+    self.layers.add_module(
+        f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
+    )
+    self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
+    self.layers.add_module(f"{name}_act{index}", act())
     if index == 0:
         inp_dim = units
-self.layers = nn.Sequential(*layers)
 self.layers.apply(tools.weight_init)

 if isinstance(self._shape, dict):
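Registering each layer through add_module, as the new MLP does above, gives submodules readable names like Actor_linear0, which show up in print(model) and in TensorBoard graphs. A minimal sketch with assumed sizes:

```python
import torch.nn as nn

name, inp_dim, units = "Actor", 32, 64   # assumed example sizes
layers = nn.Sequential()
for index in range(2):
    # Named registration instead of appending anonymous modules to a list.
    layers.add_module(f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False))
    layers.add_module(f"{name}_norm{index}", nn.LayerNorm(units, eps=1e-03))
    layers.add_module(f"{name}_act{index}", nn.SiLU())
    if index == 0:
        inp_dim = units
print(layers)  # submodules appear as Actor_linear0, Actor_norm0, Actor_act0, ...
```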
@@ -664,6 +674,7 @@ class MLP(nn.Module):
 self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape))
 self.mean_layer.apply(tools.uniform_weight_init(outscale))
 if self._std == "learned":
+    assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
     self.std_layer = nn.ModuleDict()
     for name, shape in self._shape.items():
         self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape))
@@ -672,6 +683,7 @@ class MLP(nn.Module):
 self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
 self.mean_layer.apply(tools.uniform_weight_init(outscale))
 if self._std == "learned":
+    assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
     self.std_layer = nn.Linear(units, np.prod(self._shape))
     self.std_layer.apply(tools.uniform_weight_init(outscale))

@@ -680,6 +692,7 @@ class MLP(nn.Module):
 if self._symlog_inputs:
     x = tools.symlog(x)
 out = self.layers(x)
+# Used for encoder output
 if self._shape is None:
     return out
 if isinstance(self._shape, dict):
|
||||
return self.dist(self._dist, mean, std, self._shape)
|
||||
|
||||
def dist(self, dist, mean, std, shape):
|
||||
if dist == "normal":
|
||||
return tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
torchd.normal.Normal(mean, std), len(shape)
|
||||
)
|
||||
)
|
||||
if dist == "huber":
|
||||
return tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
|
||||
)
|
||||
)
|
||||
if dist == "binary":
|
||||
return tools.Bernoulli(
|
||||
torchd.independent.Independent(
|
||||
torchd.bernoulli.Bernoulli(logits=mean), len(shape)
|
||||
)
|
||||
)
|
||||
if dist == "symlog_disc":
|
||||
return tools.DiscDist(logits=mean, device=self._device)
|
||||
if dist == "symlog_mse":
|
||||
return tools.SymlogDist(mean)
|
||||
raise NotImplementedError(dist)
|
||||
|
||||
|
||||
class ActionHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
inp_dim,
|
||||
size,
|
||||
layers,
|
||||
units,
|
||||
act=nn.ELU,
|
||||
norm=nn.LayerNorm,
|
||||
dist="trunc_normal",
|
||||
init_std=0.0,
|
||||
min_std=0.1,
|
||||
max_std=1.0,
|
||||
temp=0.1,
|
||||
outscale=1.0,
|
||||
unimix_ratio=0.01,
|
||||
):
|
||||
super(ActionHead, self).__init__()
|
||||
self._size = size
|
||||
self._layers = layers
|
||||
self._units = units
|
||||
self._dist = dist
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._init_std = init_std
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._temp = temp() if callable(temp) else temp
|
||||
|
||||
pre_layers = []
|
||||
for index in range(self._layers):
|
||||
pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
|
||||
pre_layers.append(norm(self._units, eps=1e-03))
|
||||
pre_layers.append(act())
|
||||
if index == 0:
|
||||
inp_dim = self._units
|
||||
self._pre_layers = nn.Sequential(*pre_layers)
|
||||
self._pre_layers.apply(tools.weight_init)
|
||||
|
||||
if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
|
||||
self._dist_layer = nn.Linear(self._units, 2 * self._size)
|
||||
self._dist_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
|
||||
self._dist_layer = nn.Linear(self._units, self._size)
|
||||
self._dist_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
def forward(self, features, dtype=None):
|
||||
x = features
|
||||
x = self._pre_layers(x)
|
||||
if self._dist == "tanh_normal":
|
||||
x = self._dist_layer(x)
|
||||
mean, std = torch.split(x, 2, -1)
|
||||
mean = torch.tanh(mean)
|
||||
std = F.softplus(std + self._init_std) + self._min_std
|
||||
dist = torchd.normal.Normal(mean, std)
|
||||
dist = torchd.transformed_distribution.TransformedDistribution(
|
||||
dist, tools.TanhBijector()
|
||||
)
|
||||
dist = torchd.independent.Independent(dist, 1)
|
||||
dist = tools.SampleDist(dist)
|
||||
elif self._dist == "tanh_normal_5":
|
||||
x = self._dist_layer(x)
|
||||
mean, std = torch.split(x, 2, -1)
|
||||
mean = 5 * torch.tanh(mean / 5)
|
||||
std = F.softplus(std + 5) + 5
|
||||
std = F.softplus(std) + self._min_std
|
||||
dist = torchd.normal.Normal(mean, std)
|
||||
dist = torchd.transformed_distribution.TransformedDistribution(
|
||||
dist, tools.TanhBijector()
|
||||
@ -800,33 +724,41 @@ class ActionHead(nn.Module):
|
||||
dist = torchd.independent.Independent(dist, 1)
|
||||
dist = tools.SampleDist(dist)
|
||||
elif self._dist == "normal":
|
||||
x = self._dist_layer(x)
|
||||
mean, std = torch.split(x, [self._size] * 2, -1)
|
||||
std = (self._max_std - self._min_std) * torch.sigmoid(
|
||||
std + 2.0
|
||||
) + self._min_std
|
||||
dist = torchd.normal.Normal(torch.tanh(mean), std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "normal_1":
|
||||
mean = self._dist_layer(x)
|
||||
dist = torchd.normal.Normal(mean, 1)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
|
||||
elif self._dist == "normal_std_fixed":
|
||||
dist = torchd.normal.Normal(mean, self._std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "trunc_normal":
|
||||
x = self._dist_layer(x)
|
||||
mean, std = torch.split(x, [self._size] * 2, -1)
|
||||
mean = torch.tanh(mean)
|
||||
std = 2 * torch.sigmoid(std / 2) + self._min_std
|
||||
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "onehot":
|
||||
x = self._dist_layer(x)
|
||||
dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
|
||||
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
|
||||
elif self._dist == "onehot_gumble":
|
||||
x = self._dist_layer(x)
|
||||
temp = self._temp
|
||||
dist = tools.ContDist(torchd.gumbel.Gumbel(x, 1 / temp))
|
||||
dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
|
||||
elif dist == "huber":
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
|
||||
)
|
||||
)
|
||||
elif dist == "binary":
|
||||
dist = tools.Bernoulli(
|
||||
torchd.independent.Independent(
|
||||
torchd.bernoulli.Bernoulli(logits=mean), len(shape)
|
||||
)
|
||||
)
|
||||
elif dist == "symlog_disc":
|
||||
dist = tools.DiscDist(logits=mean, device=self._device)
|
||||
elif dist == "symlog_mse":
|
||||
dist = tools.SymlogDist(mean)
|
||||
else:
|
||||
raise NotImplementedError(self._dist)
|
||||
raise NotImplementedError(dist)
|
||||
return dist
|
||||
|
||||
|
||||
|
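Net effect of the last two hunks: ActionHead is deleted and MLP.dist absorbs its action-distribution branches, so a single head class serves the actor as well as the value, reward, and cont heads. A schematic sketch of that kind of single dispatch point, using plain torch.distributions instead of the repo's tools wrappers; the function name and the small set of branches are illustrative only:

```python
import torch
import torch.distributions as torchd

def make_dist(kind: str, mean: torch.Tensor, std: torch.Tensor):
    # One factory chooses between action-style and prediction-style distributions.
    if kind == "normal":
        return torchd.Independent(torchd.Normal(torch.tanh(mean), std), 1)
    if kind == "onehot":
        return torchd.OneHotCategorical(logits=mean)
    if kind == "binary":
        return torchd.Independent(torchd.Bernoulli(logits=mean), 1)
    raise NotImplementedError(kind)

dist = make_dist("onehot", torch.zeros(6), torch.ones(6))
print(dist.sample())  # one-hot action vector of length 6
```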