changed the discount head to predict terminal
commit 628b856c63, parent 16151efb3c

configs.yaml (14 lines changed)
@@ -42,10 +42,10 @@ defaults:
   dyn_std_act: 'sigmoid2'
   dyn_min_std: 0.1
   dyn_temp_post: True
-  grad_heads: ['image', 'reward', 'discount']
+  grad_heads: ['image', 'reward', 'cont']
   units: 512
   reward_layers: 2
-  discount_layers: 2
+  cont_layers: 2
   value_layers: 2
   actor_layers: 2
   act: 'SiLU'
@@ -55,12 +55,10 @@ defaults:
   decoder_kernels: [4, 4, 4, 4]
   value_head: 'twohot_symlog'
   reward_head: 'twohot_symlog'
-  kl_lscale: '0.1'
-  kl_rscale: '0.5'
+  dyn_scale: '0.5'
+  rep_scale: '0.1'
   kl_free: '1.0'
-  kl_forward: False
-  pred_discount: True
-  discount_scale: 1.0
+  cont_scale: 1.0
   reward_scale: 1.0
   weight_decay: 0.0
   unimix_ratio: 0.01
@@ -80,7 +78,7 @@ defaults:
   value_grad_clip: 100
   actor_grad_clip: 100
   dataset_size: 1000000
-  oversample_ends: False
+  oversample_ends: True
   slow_value_target: True
   slow_target_update: 1
   slow_target_fraction: 0.02
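For reference, a hedged summary of the renames read off the hunks above (illustrative only; note that oversample_ends also flips from False to True):

# Hedged summary of the config renames visible in this file (not code from the commit).
RENAMED_KEYS = {
    "kl_lscale": "dyn_scale",          # weight on the dynamics KL term (default now '0.5')
    "kl_rscale": "rep_scale",          # weight on the representation KL term (default now '0.1')
    "discount_layers": "cont_layers",  # MLP depth of the continuation head
    "discount_scale": "cont_scale",    # loss weight of the continuation head
}
REMOVED_KEYS = ["kl_forward", "pred_discount"]  # the cont head is now built unconditionally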
dreamer.py
@@ -155,16 +155,11 @@ class Dreamer(nn.Module):
         metrics.update(mets)
         start = post
         # start['deter'] (16, 64, 512)
-        if self._config.pred_discount:  # Last step could be terminal.
-            start = {k: v[:, :-1] for k, v in post.items()}
-            context = {k: v[:, :-1] for k, v in context.items()}
         reward = lambda f, s, a: self._wm.heads["reward"](
             self._wm.dynamics.get_feat(s)
         ).mode()
         metrics.update(self._task_behavior._train(start, reward)[-1])
         if self._config.expl_behavior != "greedy":
-            if self._config.pred_discount:
-                data = {k: v[:, :-1] for k, v in data.items()}
             mets = self._expl_behavior.train(start, context, data)[-1]
             metrics.update({"expl_" + key: value for key, value in mets.items()})
         for name, value in metrics.items():
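The deleted branches trimmed the last time step whenever pred_discount was enabled, because a terminal step had no usable discount target. With a per-step cont label (1 - is_terminal, added in models.py below), every step has a target, so imagination can start from the full posterior. A small, hedged illustration of the shape difference, with sizes taken from the comment in the hunk:

import torch

# Hypothetical posterior matching the shape comment above: (batch, length, deter).
batch, length, deter = 16, 64, 512
post = {"deter": torch.zeros(batch, length, deter)}

start_old = {k: v[:, :-1] for k, v in post.items()}  # old behavior: drop last step
start_new = post                                      # new behavior: keep all steps
assert start_old["deter"].shape == (batch, length - 1, deter)
assert start_new["deter"].shape == (batch, length, deter)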
models.py (39 lines changed)
@@ -107,11 +107,10 @@ class WorldModel(nn.Module):
             dist=config.reward_head,
             outscale=0.0,
         )
-        if config.pred_discount:
-            self.heads["discount"] = networks.DenseHead(
+        self.heads["cont"] = networks.DenseHead(
             feat_size,  # pytorch version
             [],
-            config.discount_layers,
+            config.cont_layers,
             config.units,
             config.act,
             config.norm,
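Conceptually, the new cont head is a per-step binary classifier over "the episode continues". The repo builds it with networks.DenseHead (the remaining constructor arguments fall outside this hunk); below is only a minimal, framework-agnostic sketch of the same idea, with illustrative names and sizes:

from torch import nn
import torch.distributions as torchd

class ContHead(nn.Module):
    """Hedged sketch of a continuation head: features -> Bernoulli over "cont"."""

    def __init__(self, feat_size: int, layers: int = 2, units: int = 512):
        super().__init__()
        blocks, in_dim = [], feat_size
        for _ in range(layers):
            blocks += [nn.Linear(in_dim, units), nn.SiLU()]
            in_dim = units
        blocks.append(nn.Linear(in_dim, 1))
        self.net = nn.Sequential(*blocks)

    def forward(self, feat):
        logits = self.net(feat)  # (..., 1)
        # Independent(..., 1) treats the trailing singleton as an event dim,
        # so log_prob over a (..., 1) target returns a (...) tensor.
        return torchd.independent.Independent(
            torchd.bernoulli.Bernoulli(logits=logits), 1
        )

The distribution's .mean is the predicted continuation probability, which is what _compute_target consumes further down.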
@@ -129,7 +128,7 @@ class WorldModel(nn.Module):
             opt=config.opt,
             use_amp=self._use_amp,
         )
-        self._scales = dict(reward=config.reward_scale, discount=config.discount_scale)
+        self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)

     def _train(self, data):
         # action (batch_size, batch_length, act_dim)
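The scales dictionary presumably weights each head's negative log-likelihood when the world-model loss is assembled; that assembly is not part of this hunk, so the following is only a hedged sketch of the usual pattern, with illustrative names:

import torch

def weighted_head_losses(preds, data, scales):
    """Hedged sketch: weight each head's negative log-likelihood by its scale.

    preds:  dict name -> predicted torch.distributions object (a head's output)
    data:   dict name -> target tensor (including data["cont"] = 1 - is_terminal)
    scales: dict name -> float, e.g. {"reward": reward_scale, "cont": cont_scale}
    """
    losses = {}
    for name, pred in preds.items():
        like = pred.log_prob(data[name].float())        # per-step log-likelihood
        losses[name] = -torch.mean(like) * scales.get(name, 1.0)
    return losses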
@@ -143,10 +142,10 @@ class WorldModel(nn.Module):
                 embed = self.encoder(data)
                 post, prior = self.dynamics.observe(embed, data["action"])
                 kl_free = tools.schedule(self._config.kl_free, self._step)
-                kl_lscale = tools.schedule(self._config.kl_lscale, self._step)
-                kl_rscale = tools.schedule(self._config.kl_rscale, self._step)
-                kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
-                    post, prior, self._config.kl_forward, kl_free, kl_lscale, kl_rscale
+                dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
+                rep_scale = tools.schedule(self._config.rep_scale, self._step)
+                kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
+                    post, prior, kl_free, dyn_scale, rep_scale
                 )
                 losses = {}
                 likes = {}
@@ -163,10 +162,10 @@ class WorldModel(nn.Module):

         metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
         metrics["kl_free"] = kl_free
-        metrics["kl_lscale"] = kl_lscale
-        metrics["kl_rscale"] = kl_rscale
-        metrics["loss_lhs"] = to_np(loss_lhs)
-        metrics["loss_rhs"] = to_np(loss_rhs)
+        metrics["dyn_scale"] = dyn_scale
+        metrics["rep_scale"] = rep_scale
+        metrics["dyn_loss"] = to_np(dyn_loss)
+        metrics["rep_loss"] = to_np(rep_loss)
         metrics["kl"] = to_np(torch.mean(kl_value))
         with torch.cuda.amp.autocast(self._use_amp):
             metrics["prior_ent"] = to_np(
@@ -193,6 +192,11 @@ class WorldModel(nn.Module):
             obs["discount"] *= self._config.discount
             # (batch_size, batch_length) -> (batch_size, batch_length, 1)
             obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
+        if "is_terminal" in obs:
+            # this label is necessary to train cont_head
+            obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
+        else:
+            raise ValueError('"is_terminal" was not found in observation.')
         obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
         return obs

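The new label is simply the complement of the terminal flag, and under the convention used elsewhere in the diff the effective per-step discount becomes the discount factor times the predicted continuation. A small, hedged illustration with made-up values:

import torch

gamma = 0.997                                        # illustrative discount factor
is_terminal = torch.tensor([0.0, 0.0, 1.0, 0.0])     # hypothetical flags for 4 steps

cont = 1.0 - is_terminal          # training target for the cont head
discount = gamma * cont           # 0 at terminal steps, gamma elsewhere
print(discount)                   # tensor([0.9970, 0.9970, 0.0000, 0.9970])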
@@ -347,6 +351,13 @@ class ImagBehavior(nn.Module):
         metrics.update(tools.tensorstats(value.mode(), "value"))
         metrics.update(tools.tensorstats(target, "target"))
         metrics.update(tools.tensorstats(reward, "imag_reward"))
-        metrics.update(tools.tensorstats(imag_action, "imag_action"))
+        if self._config.actor_dist in ["onehot"]:
+            metrics.update(
+                tools.tensorstats(
+                    torch.argmax(imag_action, dim=-1).float(), "imag_action"
+                )
+            )
+        else:
+            metrics.update(tools.tensorstats(imag_action, "imag_action"))
         metrics["actor_ent"] = to_np(torch.mean(actor_ent))
         with tools.RequiresGrad(self):
@@ -390,9 +401,9 @@ class ImagBehavior(nn.Module):
     def _compute_target(
         self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
     ):
-        if "discount" in self._world_model.heads:
+        if "cont" in self._world_model.heads:
             inp = self._world_model.dynamics.get_feat(imag_state)
-            discount = self._world_model.heads["discount"](inp).mean
+            discount = self._config.discount * self._world_model.heads["cont"](inp).mean
         else:
             discount = self._config.discount * torch.ones_like(reward)
         if self._config.future_entropy and self._config.actor_entropy() > 0:
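With the cont head, the discount sequence fed into the bootstrapped return is gamma times the predicted continuation probability, so imagined value stops accumulating past a predicted terminal state. A hedged sketch of a generic lambda-return with such a discount (this is not the repo's tools.lambda_return, whose signature is not shown in the diff):

import torch

def lambda_return(reward, value, discount, bootstrap, lam=0.95):
    """Hedged sketch: TD(lambda)-style return over an imagined horizon.

    reward, value, discount: tensors of shape (horizon, batch);
    discount is gamma * cont_prob per step; bootstrap is the value at the end.
    """
    next_values = torch.cat([value[1:], bootstrap[None]], dim=0)
    returns = torch.zeros_like(reward)
    last = bootstrap
    for t in reversed(range(reward.shape[0])):
        # R_t = r_t + discount_t * ((1 - lam) * V(s_{t+1}) + lam * R_{t+1})
        last = reward[t] + discount[t] * ((1 - lam) * next_values[t] + lam * last)
        returns[t] = last
    return returns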
networks.py (26 lines changed)
@@ -273,28 +273,24 @@ class RSSM(nn.Module):
             std = std + self._min_std
         return {"mean": mean, "std": std}

-    def kl_loss(self, post, prior, forward, free, lscale, rscale):
+    def kl_loss(self, post, prior, free, dyn_scale, rep_scale):
         kld = torchd.kl.kl_divergence
         dist = lambda x: self.get_dist(x)
         sg = lambda x: {k: v.detach() for k, v in x.items()}
-        # forward == false -> (post, prior)
-        lhs, rhs = (prior, post) if forward else (post, prior)

-        # forward == false -> Lrep
-        value_lhs = value = kld(
-            dist(lhs) if self._discrete else dist(lhs)._dist,
-            dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist,
+        rep_loss = value = kld(
+            dist(post) if self._discrete else dist(post)._dist,
+            dist(sg(prior)) if self._discrete else dist(sg(prior))._dist,
         )
-        # forward == false -> Ldyn
-        value_rhs = kld(
-            dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist,
-            dist(rhs) if self._discrete else dist(rhs)._dist,
+        dyn_loss = kld(
+            dist(sg(post)) if self._discrete else dist(sg(post))._dist,
+            dist(prior) if self._discrete else dist(prior)._dist,
         )
-        loss_lhs = torch.clip(torch.mean(value_lhs), min=free)
-        loss_rhs = torch.clip(torch.mean(value_rhs), min=free)
-        loss = lscale * loss_lhs + rscale * loss_rhs
+        rep_loss = torch.mean(torch.clip(rep_loss, min=free))
+        dyn_loss = torch.mean(torch.clip(dyn_loss, min=free))
+        loss = dyn_scale * dyn_loss + rep_scale * rep_loss

-        return loss, value, loss_lhs, loss_rhs
+        return loss, value, dyn_loss, rep_loss


 class ConvEncoder(nn.Module):
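Stripped of the RSSM plumbing, the rewritten kl_loss computes two KL divergences between the same posterior/prior pair, differing only in where the gradient is stopped, clips each at the free-bits threshold before averaging, and mixes them with dyn_scale and rep_scale. A standalone, hedged sketch using diagonal Gaussians (the repo's discrete case works the same way on categorical distributions):

import torch
import torch.distributions as torchd

def kl_balance(post_mean, post_std, prior_mean, prior_std,
               free=1.0, dyn_scale=0.5, rep_scale=0.1):
    """Hedged sketch of the dyn/rep KL split with stop-gradients and free bits."""
    dist = lambda mean, std: torchd.Normal(mean, std)
    sg = lambda t: t.detach()

    # Representation loss: move the posterior toward a frozen prior.
    rep = torchd.kl.kl_divergence(dist(post_mean, post_std),
                                  dist(sg(prior_mean), sg(prior_std)))
    # Dynamics loss: move the prior toward a frozen posterior.
    dyn = torchd.kl.kl_divergence(dist(sg(post_mean), sg(post_std)),
                                  dist(prior_mean, prior_std))

    # Free bits: clip each element at `free` before averaging, as in the new code.
    rep = torch.mean(torch.clip(rep, min=free))
    dyn = torch.mean(torch.clip(dyn, min=free))
    return dyn_scale * dyn + rep_scale * rep

Note that the clipping order also changes relative to the old code: the free-bits floor is now applied per element before the mean, rather than to the mean itself.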