removed scheduling function

This commit is contained in:
parent 2cdba230d8
commit 16635df3e4
configs.yaml (10 lines changed)

@@ -59,9 +59,9 @@ defaults:
   {mlp_keys: '$^', cnn_keys: 'image', act: 'SiLU', norm: 'LayerNorm', cnn_depth: 32, kernel_size: 4, minres: 4, mlp_layers: 2, mlp_units: 512, cnn_sigmoid: False, image_dist: mse, vector_dist: symlog_mse}
   value_head: 'symlog_disc'
   reward_head: 'symlog_disc'
-  dyn_scale: '0.5'
-  rep_scale: '0.1'
-  kl_free: '1.0'
+  dyn_scale: 0.5
+  rep_scale: 0.1
+  kl_free: 1.0
   cont_scale: 1.0
   reward_scale: 1.0
   weight_decay: 0.0
@@ -93,10 +93,10 @@ defaults:
   discount_lambda: 0.95
   imag_horizon: 15
   imag_gradient: 'dynamics'
-  imag_gradient_mix: '0.0'
+  imag_gradient_mix: 0.0
   imag_sample: True
   actor_dist: 'normal'
-  actor_entropy: '3e-4'
+  actor_entropy: 3e-4
   actor_state_entropy: 0.0
   actor_init_std: 1.0
   actor_min_std: 0.1
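Note (not part of the diff): dropping the quotes changes how these entries are typed when configs.yaml is loaded. A minimal sketch, assuming a standard YAML loader such as PyYAML's safe_load (the repo's actual loader may differ):

    import yaml

    # Quoted values load as strings (previously handed to tools.schedule),
    # unquoted values load directly as numbers.
    old = yaml.safe_load("dyn_scale: '0.5'")
    new = yaml.safe_load("dyn_scale: 0.5")
    print(type(old["dyn_scale"]))  # <class 'str'>
    print(type(new["dyn_scale"]))  # <class 'float'>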
dreamer.py (10 lines removed)

@@ -40,16 +40,6 @@ class Dreamer(nn.Module):
         # this is update step
         self._step = logger.step // config.action_repeat
         self._update_count = 0
-        # Schedules.
-        config.actor_entropy = lambda x=config.actor_entropy: tools.schedule(
-            x, self._step
-        )
-        config.actor_state_entropy = (
-            lambda x=config.actor_state_entropy: tools.schedule(x, self._step)
-        )
-        config.imag_gradient_mix = lambda x=config.imag_gradient_mix: tools.schedule(
-            x, self._step
-        )
         self._dataset = dataset
         self._wm = models.WorldModel(obs_space, act_space, self._step, config)
         self._task_behavior = models.ImagBehavior(
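For context (not part of the diff): the removed block rebound three config entries to zero-argument lambdas that re-evaluated tools.schedule at the current step, so downstream code had to call them. A minimal standalone sketch of the old vs. new access pattern, using the actor_entropy value 3e-4 from configs.yaml and a float() stand-in for tools.schedule:

    # Old: dreamer.py rebound the attribute to a zero-argument lambda wrapping
    # tools.schedule, so every consumer had to call it.
    actor_entropy = lambda: float("3e-4")   # stand-in for tools.schedule("3e-4", step)
    scale = actor_entropy()                 # -> 0.0003

    # New: the attribute stays the plain float loaded from configs.yaml.
    actor_entropy = 3e-4
    scale = actor_entropy                   # -> 0.0003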
models.py (24 lines changed)

@@ -128,9 +128,9 @@ class WorldModel(nn.Module):
                 post, prior = self.dynamics.observe(
                     embed, data["action"], data["is_first"]
                 )
-                kl_free = tools.schedule(self._config.kl_free, self._step)
-                dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
-                rep_scale = tools.schedule(self._config.rep_scale, self._step)
+                kl_free = self._config.kl_free
+                dyn_scale = self._config.dyn_scale
+                rep_scale = self._config.rep_scale
                 kl_loss, kl_value, dyn_loss, rep_loss = self.dynamics.kl_loss(
                     post, prior, kl_free, dyn_scale, rep_scale
                 )
@@ -393,10 +393,10 @@ class ImagBehavior(nn.Module):
             discount = self._config.discount * self._world_model.heads["cont"](inp).mean
         else:
             discount = self._config.discount * torch.ones_like(reward)
-        if self._config.future_entropy and self._config.actor_entropy() > 0:
-            reward += self._config.actor_entropy() * actor_ent
-        if self._config.future_entropy and self._config.actor_state_entropy() > 0:
-            reward += self._config.actor_state_entropy() * state_ent
+        if self._config.future_entropy and self._config.actor_entropy > 0:
+            reward += self._config.actor_entropy * actor_ent
+        if self._config.future_entropy and self._config.actor_state_entropy > 0:
+            reward += self._config.actor_state_entropy * state_ent
         value = self.value(imag_feat).mode()
         target = tools.lambda_return(
             reward[1:],
@@ -450,16 +450,16 @@ class ImagBehavior(nn.Module):
                 policy.log_prob(imag_action)[:-1][:, :, None]
                 * (target - self.value(imag_feat[:-1]).mode()).detach()
             )
-            mix = self._config.imag_gradient_mix()
+            mix = self._config.imag_gradient_mix
             actor_target = mix * target + (1 - mix) * actor_target
             metrics["imag_gradient_mix"] = mix
         else:
             raise NotImplementedError(self._config.imag_gradient)
-        if not self._config.future_entropy and (self._config.actor_entropy() > 0):
-            actor_entropy = self._config.actor_entropy() * actor_ent[:-1][:, :, None]
+        if not self._config.future_entropy and self._config.actor_entropy > 0:
+            actor_entropy = self._config.actor_entropy * actor_ent[:-1][:, :, None]
             actor_target += actor_entropy
-        if not self._config.future_entropy and (self._config.actor_state_entropy() > 0):
-            state_entropy = self._config.actor_state_entropy() * state_ent[:-1]
+        if not self._config.future_entropy and (self._config.actor_state_entropy > 0):
+            state_entropy = self._config.actor_state_entropy * state_ent[:-1]
             actor_target += state_entropy
             metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy))
         actor_loss = -torch.mean(weights[:-1] * actor_target)
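Aside (not part of the diff): with the schedules gone, kl_free, dyn_scale and rep_scale are plain constants passed straight into dynamics.kl_loss. A rough sketch of the KL balancing they control, assuming Gaussian latents for illustration (the repo's dynamics uses its own distribution classes; kl_balance_sketch is a hypothetical helper name):

    import torch
    import torch.distributions as td

    def kl_balance_sketch(post, prior, kl_free=1.0, dyn_scale=0.5, rep_scale=0.1):
        # Free bits: each KL term is clipped from below at kl_free before scaling.
        sg = lambda d: td.Normal(d.mean.detach(), d.stddev.detach())  # stop-gradient copy
        dyn_loss = torch.clip(td.kl_divergence(sg(post), prior), min=kl_free)
        rep_loss = torch.clip(td.kl_divergence(post, sg(prior)), min=kl_free)
        return dyn_scale * dyn_loss + rep_scale * rep_loss

    post = td.Normal(torch.zeros(3), torch.ones(3))
    prior = td.Normal(torch.ones(3), torch.ones(3))
    print(kl_balance_sketch(post, prior))  # per-dimension balanced KL loss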
tools.py (27 lines removed)

@@ -899,33 +899,6 @@ class Until:
         return step < self._until


-def schedule(string, step):
-    try:
-        return float(string)
-    except ValueError:
-        match = re.match(r"linear\((.+),(.+),(.+)\)", string)
-        if match:
-            initial, final, duration = [float(group) for group in match.groups()]
-            mix = torch.clip(torch.Tensor([step / duration]), 0, 1)[0]
-            return (1 - mix) * initial + mix * final
-        match = re.match(r"warmup\((.+),(.+)\)", string)
-        if match:
-            warmup, value = [float(group) for group in match.groups()]
-            scale = torch.clip(step / warmup, 0, 1)
-            return scale * value
-        match = re.match(r"exp\((.+),(.+),(.+)\)", string)
-        if match:
-            initial, final, halflife = [float(group) for group in match.groups()]
-            return (initial - final) * 0.5 ** (step / halflife) + final
-        match = re.match(r"horizon\((.+),(.+),(.+)\)", string)
-        if match:
-            initial, final, duration = [float(group) for group in match.groups()]
-            mix = torch.clip(step / duration, 0, 1)
-            horizon = (1 - mix) * initial + mix * final
-            return 1 - 1 / horizon
-        raise NotImplementedError(string)
-
-
 def weight_init(m):
     if isinstance(m, nn.Linear):
         in_num = m.in_features
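For reference (not part of the diff): the removed helper accepted either a plain number or a string naming a schedule (linear, warmup, exp, horizon). A minimal sketch of how it behaved before this commit, with illustrative arguments rather than values from the repo's configs:

    # Illustrative calls to the removed tools.schedule helper:
    tools.schedule("3e-4", step=0)                   # plain number  -> 0.0003
    tools.schedule("linear(1.0,0.1,1e6)", step=5e5)  # linear decay  -> 0.55 (halfway)
    tools.schedule("exp(1.0,0.0,1e5)", step=1e5)     # exponential   -> 0.5 (one half-life)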