erased unnecessary lines
This commit is contained in:
parent
6c861ca7cb
commit
f7c505579c
@@ -55,7 +55,9 @@ class Dreamer(nn.Module):
|
|||||||
self._task_behavior = models.ImagBehavior(
|
self._task_behavior = models.ImagBehavior(
|
||||||
config, self._wm, config.behavior_stop_grad
|
config, self._wm, config.behavior_stop_grad
|
||||||
)
|
)
|
||||||
if config.compile and os.name != 'nt': # compilation is not supported on windows
|
if (
|
||||||
|
config.compile and os.name != "nt"
|
||||||
|
): # compilation is not supported on windows
|
||||||
self._wm = torch.compile(self._wm)
|
self._wm = torch.compile(self._wm)
|
||||||
self._task_behavior = torch.compile(self._task_behavior)
|
self._task_behavior = torch.compile(self._task_behavior)
|
||||||
reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
|
reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
|
||||||
@@ -156,7 +158,6 @@ class Dreamer(nn.Module):
|
|||||||
post, context, mets = self._wm._train(data)
|
post, context, mets = self._wm._train(data)
|
||||||
metrics.update(mets)
|
metrics.update(mets)
|
||||||
start = post
|
start = post
|
||||||
# start['deter'] (16, 64, 512)
|
|
||||||
reward = lambda f, s, a: self._wm.heads["reward"](
|
reward = lambda f, s, a: self._wm.heads["reward"](
|
||||||
self._wm.dynamics.get_feat(s)
|
self._wm.dynamics.get_feat(s)
|
||||||
).mode()
|
).mode()
|
||||||
|
@@ -58,7 +58,9 @@ class Plan2Explore(nn.Module):
|
|||||||
"feat": config.dyn_stoch + config.dyn_deter,
|
"feat": config.dyn_stoch + config.dyn_deter,
|
||||||
}[self._config.disag_target]
|
}[self._config.disag_target]
|
||||||
kw = dict(
|
kw = dict(
|
||||||
inp_dim=feat_size + config.num_actions if config.disag_action_cond else 0, # pytorch version
|
inp_dim=feat_size + config.num_actions
|
||||||
|
if config.disag_action_cond
|
||||||
|
else 0, # pytorch version
|
||||||
shape=size,
|
shape=size,
|
||||||
layers=config.disag_layers,
|
layers=config.disag_layers,
|
||||||
units=config.disag_units,
|
units=config.disag_units,
|
||||||
@@ -93,7 +95,9 @@ class Plan2Explore(nn.Module):
|
|||||||
}[self._config.disag_target]
|
}[self._config.disag_target]
|
||||||
inputs = context["feat"]
|
inputs = context["feat"]
|
||||||
if self._config.disag_action_cond:
|
if self._config.disag_action_cond:
|
||||||
inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1)
|
inputs = torch.concat(
|
||||||
|
[inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
|
||||||
|
)
|
||||||
metrics.update(self._train_ensemble(inputs, target))
|
metrics.update(self._train_ensemble(inputs, target))
|
||||||
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
|
||||||
return None, metrics
|
return None, metrics
|
||||||
|
@@ -399,9 +399,6 @@ class ImagBehavior(nn.Module):
|
|||||||
if self._config.future_entropy and self._config.actor_state_entropy() > 0:
|
if self._config.future_entropy and self._config.actor_state_entropy() > 0:
|
||||||
reward += self._config.actor_state_entropy() * state_ent
|
reward += self._config.actor_state_entropy() * state_ent
|
||||||
value = self.value(imag_feat).mode()
|
value = self.value(imag_feat).mode()
|
||||||
# value(15, 960, ch)
|
|
||||||
# action(15, 960, ch)
|
|
||||||
# discount(15, 960, ch)
|
|
||||||
target = tools.lambda_return(
|
target = tools.lambda_return(
|
||||||
reward[:-1],
|
reward[:-1],
|
||||||
value[:-1],
|
value[:-1],
|
||||||
|
@@ -215,7 +215,9 @@ class RSSM(nn.Module):
|
|||||||
is_first,
|
is_first,
|
||||||
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
|
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
|
||||||
)
|
)
|
||||||
prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
prev_state[key] = (
|
||||||
|
val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
||||||
|
)
|
||||||
|
|
||||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
prior = self.img_step(prev_state, prev_action, None, sample)
|
||||||
if self._shared:
|
if self._shared:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user